diff --git a/b_asic/architecture.py b/b_asic/architecture.py index cdf90d6dd0eeb98b693bc80918b1d6010dfa5504..bdf1495293de6c6f8f12df755cbab1be5290c8ba 100644 --- a/b_asic/architecture.py +++ b/b_asic/architecture.py @@ -2,13 +2,70 @@ B-ASIC architecture classes. """ from collections import defaultdict -from typing import List, Optional, Set, cast +from typing import Dict, List, Optional, Set, Tuple, cast from b_asic.process import MemoryVariable, OperatorProcess, PlainMemoryVariable from b_asic.resources import ProcessCollection -class ProcessingElement: +def _interconnect_dict() -> int: + # Needed as pickle does not support lambdas + return 0 + + +class Resource: + """ + Base class for resource. + + Parameters + ---------- + process_collection : ProcessCollection + The process collection containing processes to be mapped to resource. + entity_name : str, optional + The name of the resulting entity. + + """ + + def __init__( + self, process_collection: ProcessCollection, entity_name: Optional[str] = None + ): + if not len(process_collection): + raise ValueError("Do not create Resource with empty ProcessCollection") + self._collection = process_collection + self._entity_name = entity_name + + def __repr__(self): + return self._entity_name + + def __iter__(self): + return iter(self._collection) + + def set_entity_name(self, entity_name: str): + """ + Set entity name of resource. + + Parameters + ---------- + entity_name : str + The entity name. + """ + self._entity_name = entity_name + + def write_code(self, path: str) -> None: + """ + Write VHDL code for resource. + + Parameters + ---------- + path : str + Directory to write code in. + """ + if not self._entity_name: + raise ValueError("Entity name must be set") + raise NotImplementedError + + +class ProcessingElement(Resource): """ Create a processing element for a ProcessCollection with OperatorProcesses. @@ -23,10 +80,7 @@ class ProcessingElement: def __init__( self, process_collection: ProcessCollection, entity_name: Optional[str] = None ): - if not len(process_collection): - raise ValueError( - "Do not create ProcessingElement with empty ProcessCollection" - ) + super().__init__(process_collection=process_collection, entity_name=entity_name) if not all( isinstance(operator, OperatorProcess) for operator in process_collection.collection @@ -51,27 +105,8 @@ class ProcessingElement: def processes(self) -> Set[OperatorProcess]: return {cast(OperatorProcess, p) for p in self._collection} - def __repr__(self): - return self._entity_name or self._type_name - - def set_entity_name(self, entity_name: str): - self._entity_name = entity_name - - def write_code(self, path: str) -> None: - """ - Write VHDL code for processing element. - Parameters - ---------- - path : str - Directory to write code in. - """ - if not self._entity_name: - raise ValueError("Entity name must be set") - raise NotImplementedError - - -class Memory: +class Memory(Resource): """ Create a memory from a ProcessCollection with memory variables. @@ -91,8 +126,7 @@ class Memory: memory_type: str = "RAM", entity_name: Optional[str] = None, ): - if not len(process_collection): - raise ValueError("Do not create Memory with empty ProcessCollection") + super().__init__(process_collection=process_collection, entity_name=entity_name) if not all( isinstance(operator, (MemoryVariable, PlainMemoryVariable)) for operator in process_collection.collection @@ -105,31 +139,7 @@ class Memory: raise ValueError( f"memory_type must be 'RAM' or 'register', not {memory_type!r}" ) - self._collection = process_collection self._memory_type = memory_type - self._entity_name = entity_name - - def __iter__(self): - return iter(self._collection) - - def set_entity_name(self, entity_name: str): - self._entity_name = entity_name - - def __repr__(self): - return self._entity_name or self._memory_type - - def write_code(self, path: str) -> None: - """ - Write VHDL code for memory. - - Parameters - ---------- - path : str - Directory to write code in. - """ - if not self._entity_name: - raise ValueError("Entity name must be set") - raise NotImplementedError class Architecture: @@ -195,7 +205,8 @@ class Architecture: ] = self._operation_inport_to_resource[read_port] def validate_ports(self): - # Validate inputs and outputs of memory variables in all the memories in this architecture + # Validate inputs and outputs of memory variables in all the memories in this + # architecture memory_read_ports = set() memory_write_ports = set() for memory in self.memories: @@ -242,8 +253,22 @@ class Architecture: raise NotImplementedError def get_interconnects_for_memory(self, mem: Memory): - d_in = defaultdict(lambda: 0) - d_out = defaultdict(lambda: 0) + """ + Return a dictionary with interconnect information for a Memory. + + Parameters + ---------- + mem : :class:`Memory` + The memory to obtain information about. + + Returns + ------- + (dict, dict) + A dictionary with the ProcessingElements that are connected to the write and + read ports, respectively, with counts of the number of accesses. + """ + d_in = defaultdict(_interconnect_dict) + d_out = defaultdict(_interconnect_dict) for var in mem._collection: var = cast(MemoryVariable, var) d_in[self._operation_outport_to_resource[var.write_port]] += 1 @@ -251,10 +276,31 @@ class Architecture: d_out[self._operation_inport_to_resource[read_port]] += 1 return dict(d_in), dict(d_out) - def get_interconnects_for_pe(self, pe: ProcessingElement): + def get_interconnects_for_pe( + self, pe: ProcessingElement + ) -> Tuple[List[Dict[str, int]], List[Dict[str, int]]]: + """ + Return lists of dictionaries with interconnect information for a + ProcessingElement. + + Parameters + ---------- + pe : :class:`ProcessingElement` + The processing element to get information for. + + Returns + ------- + list + List of dictionaries indicating the sources for each inport and the + frequency of accesses. + list + List of dictionaries indicating the sources for each outport and the + frequency of accesses. + + """ ops = cast(List[OperatorProcess], list(pe._collection)) - d_in = [defaultdict(lambda: 0) for _ in ops[0].operation.inputs] - d_out = [defaultdict(lambda: 0) for _ in ops[0].operation.outputs] + d_in = [defaultdict(_interconnect_dict) for _ in ops[0].operation.inputs] + d_out = [defaultdict(_interconnect_dict) for _ in ops[0].operation.outputs] for var in pe._collection: var = cast(OperatorProcess, var) for i, input in enumerate(var.operation.inputs): @@ -264,6 +310,14 @@ class Architecture: return [dict(d) for d in d_in], [dict(d) for d in d_out] def set_entity_name(self, entity_name: str): + """ + Set entity name of architecture. + + Parameters + ---------- + entity_name : str + The entity name. + """ self._entity_name = entity_name @property diff --git a/test/test_architecture.py b/test/test_architecture.py index 3d5e59df09acc12a076903f367d9ae315acfdab3..ba9f910f6d590850f1626ce5e7352a8379d43153 100644 --- a/test/test_architecture.py +++ b/test/test_architecture.py @@ -22,7 +22,7 @@ def test_processing_element_exceptions(schedule_direct_form_iir_lp_filter: Sched ProcessingElement(mvs) empty_collection = ProcessCollection(collection=set(), schedule_time=5) with pytest.raises( - ValueError, match="Do not create ProcessingElement with empty ProcessCollection" + ValueError, match="Do not create Resource with empty ProcessCollection" ): ProcessingElement(empty_collection) @@ -52,7 +52,7 @@ def test_memory_exceptions(schedule_direct_form_iir_lp_filter: Schedule): operations = schedule_direct_form_iir_lp_filter.get_operations() empty_collection = ProcessCollection(collection=set(), schedule_time=5) with pytest.raises( - ValueError, match="Do not create Memory with empty ProcessCollection" + ValueError, match="Do not create Resource with empty ProcessCollection" ): Memory(empty_collection) with pytest.raises(