Skip to content
Snippets Groups Projects
Commit 22fe6ad7 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Refactor resources

parent f2f92899
No related branches found
No related tags found
1 merge request!342Refactor resources
Pipeline #96614 passed
...@@ -2,13 +2,70 @@ ...@@ -2,13 +2,70 @@
B-ASIC architecture classes. B-ASIC architecture classes.
""" """
from collections import defaultdict from collections import defaultdict
from typing import List, Optional, Set, cast from typing import Dict, List, Optional, Set, Tuple, cast
from b_asic.process import MemoryVariable, OperatorProcess, PlainMemoryVariable from b_asic.process import MemoryVariable, OperatorProcess, PlainMemoryVariable
from b_asic.resources import ProcessCollection from b_asic.resources import ProcessCollection
class ProcessingElement: def _interconnect_dict() -> int:
# Needed as pickle does not support lambdas
return 0
class Resource:
"""
Base class for resource.
Parameters
----------
process_collection : ProcessCollection
The process collection containing processes to be mapped to resource.
entity_name : str, optional
The name of the resulting entity.
"""
def __init__(
self, process_collection: ProcessCollection, entity_name: Optional[str] = None
):
if not len(process_collection):
raise ValueError("Do not create Resource with empty ProcessCollection")
self._collection = process_collection
self._entity_name = entity_name
def __repr__(self):
return self._entity_name
def __iter__(self):
return iter(self._collection)
def set_entity_name(self, entity_name: str):
"""
Set entity name of resource.
Parameters
----------
entity_name : str
The entity name.
"""
self._entity_name = entity_name
def write_code(self, path: str) -> None:
"""
Write VHDL code for resource.
Parameters
----------
path : str
Directory to write code in.
"""
if not self._entity_name:
raise ValueError("Entity name must be set")
raise NotImplementedError
class ProcessingElement(Resource):
""" """
Create a processing element for a ProcessCollection with OperatorProcesses. Create a processing element for a ProcessCollection with OperatorProcesses.
...@@ -23,10 +80,7 @@ class ProcessingElement: ...@@ -23,10 +80,7 @@ class ProcessingElement:
def __init__( def __init__(
self, process_collection: ProcessCollection, entity_name: Optional[str] = None self, process_collection: ProcessCollection, entity_name: Optional[str] = None
): ):
if not len(process_collection): super().__init__(process_collection=process_collection, entity_name=entity_name)
raise ValueError(
"Do not create ProcessingElement with empty ProcessCollection"
)
if not all( if not all(
isinstance(operator, OperatorProcess) isinstance(operator, OperatorProcess)
for operator in process_collection.collection for operator in process_collection.collection
...@@ -51,27 +105,8 @@ class ProcessingElement: ...@@ -51,27 +105,8 @@ class ProcessingElement:
def processes(self) -> Set[OperatorProcess]: def processes(self) -> Set[OperatorProcess]:
return {cast(OperatorProcess, p) for p in self._collection} return {cast(OperatorProcess, p) for p in self._collection}
def __repr__(self):
return self._entity_name or self._type_name
def set_entity_name(self, entity_name: str):
self._entity_name = entity_name
def write_code(self, path: str) -> None:
"""
Write VHDL code for processing element.
Parameters class Memory(Resource):
----------
path : str
Directory to write code in.
"""
if not self._entity_name:
raise ValueError("Entity name must be set")
raise NotImplementedError
class Memory:
""" """
Create a memory from a ProcessCollection with memory variables. Create a memory from a ProcessCollection with memory variables.
...@@ -91,8 +126,7 @@ class Memory: ...@@ -91,8 +126,7 @@ class Memory:
memory_type: str = "RAM", memory_type: str = "RAM",
entity_name: Optional[str] = None, entity_name: Optional[str] = None,
): ):
if not len(process_collection): super().__init__(process_collection=process_collection, entity_name=entity_name)
raise ValueError("Do not create Memory with empty ProcessCollection")
if not all( if not all(
isinstance(operator, (MemoryVariable, PlainMemoryVariable)) isinstance(operator, (MemoryVariable, PlainMemoryVariable))
for operator in process_collection.collection for operator in process_collection.collection
...@@ -105,31 +139,7 @@ class Memory: ...@@ -105,31 +139,7 @@ class Memory:
raise ValueError( raise ValueError(
f"memory_type must be 'RAM' or 'register', not {memory_type!r}" f"memory_type must be 'RAM' or 'register', not {memory_type!r}"
) )
self._collection = process_collection
self._memory_type = memory_type self._memory_type = memory_type
self._entity_name = entity_name
def __iter__(self):
return iter(self._collection)
def set_entity_name(self, entity_name: str):
self._entity_name = entity_name
def __repr__(self):
return self._entity_name or self._memory_type
def write_code(self, path: str) -> None:
"""
Write VHDL code for memory.
Parameters
----------
path : str
Directory to write code in.
"""
if not self._entity_name:
raise ValueError("Entity name must be set")
raise NotImplementedError
class Architecture: class Architecture:
...@@ -195,7 +205,8 @@ class Architecture: ...@@ -195,7 +205,8 @@ class Architecture:
] = self._operation_inport_to_resource[read_port] ] = self._operation_inport_to_resource[read_port]
def validate_ports(self): def validate_ports(self):
# Validate inputs and outputs of memory variables in all the memories in this architecture # Validate inputs and outputs of memory variables in all the memories in this
# architecture
memory_read_ports = set() memory_read_ports = set()
memory_write_ports = set() memory_write_ports = set()
for memory in self.memories: for memory in self.memories:
...@@ -242,8 +253,22 @@ class Architecture: ...@@ -242,8 +253,22 @@ class Architecture:
raise NotImplementedError raise NotImplementedError
def get_interconnects_for_memory(self, mem: Memory): def get_interconnects_for_memory(self, mem: Memory):
d_in = defaultdict(lambda: 0) """
d_out = defaultdict(lambda: 0) Return a dictionary with interconnect information for a Memory.
Parameters
----------
mem : :class:`Memory`
The memory to obtain information about.
Returns
-------
(dict, dict)
A dictionary with the ProcessingElements that are connected to the write and
read ports, respectively, with counts of the number of accesses.
"""
d_in = defaultdict(_interconnect_dict)
d_out = defaultdict(_interconnect_dict)
for var in mem._collection: for var in mem._collection:
var = cast(MemoryVariable, var) var = cast(MemoryVariable, var)
d_in[self._operation_outport_to_resource[var.write_port]] += 1 d_in[self._operation_outport_to_resource[var.write_port]] += 1
...@@ -251,10 +276,31 @@ class Architecture: ...@@ -251,10 +276,31 @@ class Architecture:
d_out[self._operation_inport_to_resource[read_port]] += 1 d_out[self._operation_inport_to_resource[read_port]] += 1
return dict(d_in), dict(d_out) return dict(d_in), dict(d_out)
def get_interconnects_for_pe(self, pe: ProcessingElement): def get_interconnects_for_pe(
self, pe: ProcessingElement
) -> Tuple[List[Dict[str, int]], List[Dict[str, int]]]:
"""
Return lists of dictionaries with interconnect information for a
ProcessingElement.
Parameters
----------
pe : :class:`ProcessingElement`
The processing element to get information for.
Returns
-------
list
List of dictionaries indicating the sources for each inport and the
frequency of accesses.
list
List of dictionaries indicating the sources for each outport and the
frequency of accesses.
"""
ops = cast(List[OperatorProcess], list(pe._collection)) ops = cast(List[OperatorProcess], list(pe._collection))
d_in = [defaultdict(lambda: 0) for _ in ops[0].operation.inputs] d_in = [defaultdict(_interconnect_dict) for _ in ops[0].operation.inputs]
d_out = [defaultdict(lambda: 0) for _ in ops[0].operation.outputs] d_out = [defaultdict(_interconnect_dict) for _ in ops[0].operation.outputs]
for var in pe._collection: for var in pe._collection:
var = cast(OperatorProcess, var) var = cast(OperatorProcess, var)
for i, input in enumerate(var.operation.inputs): for i, input in enumerate(var.operation.inputs):
...@@ -264,6 +310,14 @@ class Architecture: ...@@ -264,6 +310,14 @@ class Architecture:
return [dict(d) for d in d_in], [dict(d) for d in d_out] return [dict(d) for d in d_in], [dict(d) for d in d_out]
def set_entity_name(self, entity_name: str): def set_entity_name(self, entity_name: str):
"""
Set entity name of architecture.
Parameters
----------
entity_name : str
The entity name.
"""
self._entity_name = entity_name self._entity_name = entity_name
@property @property
......
...@@ -22,7 +22,7 @@ def test_processing_element_exceptions(schedule_direct_form_iir_lp_filter: Sched ...@@ -22,7 +22,7 @@ def test_processing_element_exceptions(schedule_direct_form_iir_lp_filter: Sched
ProcessingElement(mvs) ProcessingElement(mvs)
empty_collection = ProcessCollection(collection=set(), schedule_time=5) empty_collection = ProcessCollection(collection=set(), schedule_time=5)
with pytest.raises( with pytest.raises(
ValueError, match="Do not create ProcessingElement with empty ProcessCollection" ValueError, match="Do not create Resource with empty ProcessCollection"
): ):
ProcessingElement(empty_collection) ProcessingElement(empty_collection)
...@@ -52,7 +52,7 @@ def test_memory_exceptions(schedule_direct_form_iir_lp_filter: Schedule): ...@@ -52,7 +52,7 @@ def test_memory_exceptions(schedule_direct_form_iir_lp_filter: Schedule):
operations = schedule_direct_form_iir_lp_filter.get_operations() operations = schedule_direct_form_iir_lp_filter.get_operations()
empty_collection = ProcessCollection(collection=set(), schedule_time=5) empty_collection = ProcessCollection(collection=set(), schedule_time=5)
with pytest.raises( with pytest.raises(
ValueError, match="Do not create Memory with empty ProcessCollection" ValueError, match="Do not create Resource with empty ProcessCollection"
): ):
Memory(empty_collection) Memory(empty_collection)
with pytest.raises( with pytest.raises(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment