Skip to content
Snippets Groups Projects
Commit 22fe6ad7 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Refactor resources

parent f2f92899
No related branches found
No related tags found
1 merge request!342Refactor resources
Pipeline #96614 passed
......@@ -2,13 +2,70 @@
B-ASIC architecture classes.
"""
from collections import defaultdict
from typing import List, Optional, Set, cast
from typing import Dict, List, Optional, Set, Tuple, cast
from b_asic.process import MemoryVariable, OperatorProcess, PlainMemoryVariable
from b_asic.resources import ProcessCollection
class ProcessingElement:
def _interconnect_dict() -> int:
# Needed as pickle does not support lambdas
return 0
class Resource:
"""
Base class for resource.
Parameters
----------
process_collection : ProcessCollection
The process collection containing processes to be mapped to resource.
entity_name : str, optional
The name of the resulting entity.
"""
def __init__(
self, process_collection: ProcessCollection, entity_name: Optional[str] = None
):
if not len(process_collection):
raise ValueError("Do not create Resource with empty ProcessCollection")
self._collection = process_collection
self._entity_name = entity_name
def __repr__(self):
return self._entity_name
def __iter__(self):
return iter(self._collection)
def set_entity_name(self, entity_name: str):
"""
Set entity name of resource.
Parameters
----------
entity_name : str
The entity name.
"""
self._entity_name = entity_name
def write_code(self, path: str) -> None:
"""
Write VHDL code for resource.
Parameters
----------
path : str
Directory to write code in.
"""
if not self._entity_name:
raise ValueError("Entity name must be set")
raise NotImplementedError
class ProcessingElement(Resource):
"""
Create a processing element for a ProcessCollection with OperatorProcesses.
......@@ -23,10 +80,7 @@ class ProcessingElement:
def __init__(
self, process_collection: ProcessCollection, entity_name: Optional[str] = None
):
if not len(process_collection):
raise ValueError(
"Do not create ProcessingElement with empty ProcessCollection"
)
super().__init__(process_collection=process_collection, entity_name=entity_name)
if not all(
isinstance(operator, OperatorProcess)
for operator in process_collection.collection
......@@ -51,27 +105,8 @@ class ProcessingElement:
def processes(self) -> Set[OperatorProcess]:
return {cast(OperatorProcess, p) for p in self._collection}
def __repr__(self):
return self._entity_name or self._type_name
def set_entity_name(self, entity_name: str):
self._entity_name = entity_name
def write_code(self, path: str) -> None:
"""
Write VHDL code for processing element.
Parameters
----------
path : str
Directory to write code in.
"""
if not self._entity_name:
raise ValueError("Entity name must be set")
raise NotImplementedError
class Memory:
class Memory(Resource):
"""
Create a memory from a ProcessCollection with memory variables.
......@@ -91,8 +126,7 @@ class Memory:
memory_type: str = "RAM",
entity_name: Optional[str] = None,
):
if not len(process_collection):
raise ValueError("Do not create Memory with empty ProcessCollection")
super().__init__(process_collection=process_collection, entity_name=entity_name)
if not all(
isinstance(operator, (MemoryVariable, PlainMemoryVariable))
for operator in process_collection.collection
......@@ -105,31 +139,7 @@ class Memory:
raise ValueError(
f"memory_type must be 'RAM' or 'register', not {memory_type!r}"
)
self._collection = process_collection
self._memory_type = memory_type
self._entity_name = entity_name
def __iter__(self):
return iter(self._collection)
def set_entity_name(self, entity_name: str):
self._entity_name = entity_name
def __repr__(self):
return self._entity_name or self._memory_type
def write_code(self, path: str) -> None:
"""
Write VHDL code for memory.
Parameters
----------
path : str
Directory to write code in.
"""
if not self._entity_name:
raise ValueError("Entity name must be set")
raise NotImplementedError
class Architecture:
......@@ -195,7 +205,8 @@ class Architecture:
] = self._operation_inport_to_resource[read_port]
def validate_ports(self):
# Validate inputs and outputs of memory variables in all the memories in this architecture
# Validate inputs and outputs of memory variables in all the memories in this
# architecture
memory_read_ports = set()
memory_write_ports = set()
for memory in self.memories:
......@@ -242,8 +253,22 @@ class Architecture:
raise NotImplementedError
def get_interconnects_for_memory(self, mem: Memory):
d_in = defaultdict(lambda: 0)
d_out = defaultdict(lambda: 0)
"""
Return a dictionary with interconnect information for a Memory.
Parameters
----------
mem : :class:`Memory`
The memory to obtain information about.
Returns
-------
(dict, dict)
A dictionary with the ProcessingElements that are connected to the write and
read ports, respectively, with counts of the number of accesses.
"""
d_in = defaultdict(_interconnect_dict)
d_out = defaultdict(_interconnect_dict)
for var in mem._collection:
var = cast(MemoryVariable, var)
d_in[self._operation_outport_to_resource[var.write_port]] += 1
......@@ -251,10 +276,31 @@ class Architecture:
d_out[self._operation_inport_to_resource[read_port]] += 1
return dict(d_in), dict(d_out)
def get_interconnects_for_pe(self, pe: ProcessingElement):
def get_interconnects_for_pe(
self, pe: ProcessingElement
) -> Tuple[List[Dict[str, int]], List[Dict[str, int]]]:
"""
Return lists of dictionaries with interconnect information for a
ProcessingElement.
Parameters
----------
pe : :class:`ProcessingElement`
The processing element to get information for.
Returns
-------
list
List of dictionaries indicating the sources for each inport and the
frequency of accesses.
list
List of dictionaries indicating the sources for each outport and the
frequency of accesses.
"""
ops = cast(List[OperatorProcess], list(pe._collection))
d_in = [defaultdict(lambda: 0) for _ in ops[0].operation.inputs]
d_out = [defaultdict(lambda: 0) for _ in ops[0].operation.outputs]
d_in = [defaultdict(_interconnect_dict) for _ in ops[0].operation.inputs]
d_out = [defaultdict(_interconnect_dict) for _ in ops[0].operation.outputs]
for var in pe._collection:
var = cast(OperatorProcess, var)
for i, input in enumerate(var.operation.inputs):
......@@ -264,6 +310,14 @@ class Architecture:
return [dict(d) for d in d_in], [dict(d) for d in d_out]
def set_entity_name(self, entity_name: str):
"""
Set entity name of architecture.
Parameters
----------
entity_name : str
The entity name.
"""
self._entity_name = entity_name
@property
......
......@@ -22,7 +22,7 @@ def test_processing_element_exceptions(schedule_direct_form_iir_lp_filter: Sched
ProcessingElement(mvs)
empty_collection = ProcessCollection(collection=set(), schedule_time=5)
with pytest.raises(
ValueError, match="Do not create ProcessingElement with empty ProcessCollection"
ValueError, match="Do not create Resource with empty ProcessCollection"
):
ProcessingElement(empty_collection)
......@@ -52,7 +52,7 @@ def test_memory_exceptions(schedule_direct_form_iir_lp_filter: Schedule):
operations = schedule_direct_form_iir_lp_filter.get_operations()
empty_collection = ProcessCollection(collection=set(), schedule_time=5)
with pytest.raises(
ValueError, match="Do not create Memory with empty ProcessCollection"
ValueError, match="Do not create Resource with empty ProcessCollection"
):
Memory(empty_collection)
with pytest.raises(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment