diff --git a/b_asic/codegen/vhdl/entity.py b/b_asic/codegen/vhdl/entity.py index 60c247500a79fe213e40a269485d4ae143e8caeb..8e2328e86318a353fd370c0d180b7db41cf55e5c 100644 --- a/b_asic/codegen/vhdl/entity.py +++ b/b_asic/codegen/vhdl/entity.py @@ -55,7 +55,7 @@ def memory_based_storage( # Write the input port specification f.write(f'{2*VHDL_TAB}-- Memory port I/O\n') read_ports: set[Port] = set( - sum((mv.read_ports for mv in collection), ()) + read_port for mv in collection for read_port in mv.read_ports ) # type: ignore for idx, read_port in enumerate(read_ports): port_name = read_port if isinstance(read_port, int) else read_port.name diff --git a/b_asic/process.py b/b_asic/process.py index 626d04d4f30fa56b9e414c8fefbb90dc5f6c5e47..5a1cc620a2a9d7916f61c0458befd114fcb32efd 100644 --- a/b_asic/process.py +++ b/b_asic/process.py @@ -1,6 +1,6 @@ """B-ASIC classes representing resource usage.""" -from typing import Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, cast from b_asic.operation import Operation from b_asic.port import InputPort, OutputPort @@ -105,7 +105,149 @@ class OperatorProcess(Process): return f"OperatorProcess({self.start_time}, {self.operation}, {self.name!r})" -class MemoryVariable(Process): +class MemoryProcess(Process): + """ + Intermediate class (abstract) for memory processes. + + Different from regular :class:`Processe` objects, :class:`MemoryProcess` objects + can contain multiple read accesses and can be split into two new + :class:`MemoryProcess` objects based on these read times. + + Parameters + ---------- + write_time : int + Start time of process. + life_times : list of int + List of ints representing times after ``start_time`` this process is accessed. + name : str, default="" + Name of the process. + """ + + def __init__( + self, + write_time: int, + life_times: List[int], + name: str = "", + ): + pass + self._life_times = life_times + super().__init__( + start_time=write_time, + execution_time=max(self._life_times), + name=name, + ) + + @property + def read_times(self) -> List[int]: + return list(self.start_time + read for read in self._life_times) + + @property + def life_times(self) -> List[int]: + return self._life_times + + @property + def reads(self) -> Dict[Any, int]: + raise NotImplementedError("MultiReadProcess should be derived from") + + @property + def read_ports(self) -> List[Any]: + raise NotImplementedError("MultiReadProcess should be derived from") + + @property + def write_port(self) -> Any: + raise NotImplementedError("MultiReadProcess should be derived from") + + def split_on_length( + self, + length: int = 0, + ) -> Tuple[Optional["MemoryProcess"], Optional["MemoryProcess"]]: + """ + Split this :class:`MemoryProcess` into two new :class:`MemoryProcess` objects, + based on lifetimes of the read accesses. + + Parameters + ---------- + length : int, default: 0 + The life time length to split on. Length is inclusive for the smaller + process. + + Returns + ------- + Two-tuple where the first element is a :class:`MemoryProcess` consisting + of reads with read times smaller than or equal to ``length`` (or None if no such + reads exists), and vice-versa for the other tuple element. + """ + reads = self.reads + short_reads = {k: v for k, v in filter(lambda t: t[1] <= length, reads.items())} + long_reads = {k: v for k, v in filter(lambda t: t[1] > length, reads.items())} + short_process = None + long_process = None + if short_reads: + # Create a new Process of type self (which is a derived variant of + # MultiReadProcess) by calling the self constructor + short_process = type(self)( + self.start_time, # type: ignore + self.write_port, # type: ignore + short_reads, # type: ignore + self.name, # type: ignore + ) + if long_reads: + # Create a new Process of type self (which is a derived variant of + # MultiReadProcess) by calling the self constructor + long_process = type(self)( + self.start_time, # type: ignore + self.write_port, # type: ignore + long_reads, # type: ignore + self.name, # type: ignore + ) + return short_process, long_process + + def _add_life_time(self, life_time: int): + """ + Add a lifetime to this :class:`~b_asic.process.MultiReadProcess` set of + lifetimes. + + If the lifetime specified by ``life_time`` is already in this + :class:`~b_asic.process.MultiReadProcess`, nothing happens + + After adding a lifetime from this :class:`~b_asic.process.MultiReadProcess`, + the execution time is re-evaluated. + + Parameters + ---------- + life_time : int + The lifetime to add to this :class:`~b_asic.process.MultiReadProcess`. + """ + if life_time not in self.life_times: + self._life_times.append(life_time) + self._execution_time = max(self.life_times) + + def _remove_life_time(self, life_time: int): + """ + Remove a lifetime from this :class:`~b_asic.process.MultiReadProcess` + set of lifetimes. + + After removing a lifetime from this :class:`~b_asic.process.MultiReadProcess`, + the execution time is re-evaluated. + + Raises :class:`KeyError` if the specified lifetime is not a lifetime of this + :class:`~b_asic.process.MultiReadProcess`. + + Parameters + ---------- + life_time : int + The lifetime to remove from this :class:`~b_asic.process.MultiReadProcess`. + """ + if life_time not in self.life_times: + raise KeyError( + f"Process {self.name}: {life_time} not in life_times: {self.life_times}" + ) + else: + self._life_times.remove(life_time) + self._execution_time = max(self.life_times) + + +class MemoryVariable(MemoryProcess): """ Object that corresponds to a memory variable. @@ -130,13 +272,12 @@ class MemoryVariable(Process): reads: Dict[InputPort, int], name: Optional[str] = None, ): - self._read_ports = tuple(reads.keys()) - self._life_times = tuple(reads.values()) + self._read_ports = list(reads.keys()) self._reads = reads self._write_port = write_port super().__init__( - start_time=write_time, - execution_time=max(self._life_times), + write_time=write_time, + life_times=list(reads.values()), name=name, ) @@ -145,11 +286,7 @@ class MemoryVariable(Process): return self._reads @property - def life_times(self) -> Tuple[int, ...]: - return self._life_times - - @property - def read_ports(self) -> Tuple[InputPort, ...]: + def read_ports(self) -> List[InputPort]: return self._read_ports @property @@ -163,12 +300,36 @@ class MemoryVariable(Process): f" {reads!r}, {self.name!r})" ) - @property - def read_times(self) -> Tuple[int, ...]: - return tuple(self.start_time + read for read in self._life_times) + def split_on_length( + self, + length: int = 0, + ) -> Tuple[Optional["MemoryVariable"], Optional["MemoryVariable"]]: + """ + Split this :class:`MemoryVariable` into two new :class:`MemoryVariable` objects, + based on lifetimes of read accesses. + + Parameters + ---------- + length : int, default: 0 + The lifetime length to split on. Length is inclusive for the smaller + process. + + Returns + ------- + Two-tuple where the first element is a :class:`MemoryVariable` consisting + of reads with read times smaller than or equal to ``length`` (or None if no such + reads exists), and vice-versa for the other tuple element. + """ + # This method exists only for documentation purposes and for generating correct + # type annotations when calling it. Just call super().split_on_length() in here. + short_process, long_process = super().split_on_length(length) + return ( + cast(Optional["MemoryVariable"], short_process), + cast(Optional["MemoryVariable"], long_process), + ) -class PlainMemoryVariable(Process): +class PlainMemoryVariable(MemoryProcess): """ Object that corresponds to a memory variable which only use numbers for ports. @@ -196,8 +357,7 @@ class PlainMemoryVariable(Process): reads: Dict[int, int], name: Optional[str] = None, ): - self._read_ports = tuple(reads.keys()) - self._life_times = tuple(reads.values()) + self._read_ports = list(reads.keys()) self._write_port = write_port self._reads = reads if name is None: @@ -205,8 +365,8 @@ class PlainMemoryVariable(Process): PlainMemoryVariable._name_cnt += 1 super().__init__( - start_time=write_time, - execution_time=max(self._life_times), + write_time=write_time, + life_times=list(reads.values()), name=name, ) @@ -215,11 +375,7 @@ class PlainMemoryVariable(Process): return self._reads @property - def life_times(self) -> Tuple[int, ...]: - return self._life_times - - @property - def read_ports(self) -> Tuple[int, ...]: + def read_ports(self) -> List[int]: return self._read_ports @property @@ -233,9 +389,33 @@ class PlainMemoryVariable(Process): f" {reads!r}, {self.name!r})" ) - @property - def read_times(self) -> Tuple[int, ...]: - return tuple(self.start_time + read for read in self._life_times) + def split_on_length( + self, + length: int = 0, + ) -> Tuple[Optional["PlainMemoryVariable"], Optional["PlainMemoryVariable"]]: + """ + Split this :class:`PlainMemoryVariable` into two new + :class:`PlainMemoryVariable` objects, based on lifetimes of read accesses. + + Parameters + ---------- + length : int, default: 0 + The lifetime length to split on. Length is inclusive for the smaller + process. + + Returns + ------- + Two-tuple where the first element is a :class:`PlainMemoryVariable` consisting + of reads with read times smaller than or equal to ``length`` (or None if no such + reads exists), and vice-versa for the other tuple element. + """ + # This method exists only for documentation purposes and for generating correct + # type annotations when calling it. Just call super().split_on_length() in here. + short_process, long_process = super().split_on_length(length) + return ( + cast(Optional["PlainMemoryVariable"], short_process), + cast(Optional["PlainMemoryVariable"], long_process), + ) # Static counter for default names _name_cnt = 0 diff --git a/b_asic/resources.py b/b_asic/resources.py index f5ab1c76f8d611b111f3c57f61fee9c751caba0a..487d172b6da15643c995baa5d81ddbd73493c8cb 100644 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -11,7 +11,13 @@ from matplotlib.ticker import MaxNLocator from b_asic._preferences import LATENCY_COLOR, WARNING_COLOR from b_asic.codegen.vhdl.common import is_valid_vhdl_identifier -from b_asic.process import MemoryVariable, OperatorProcess, PlainMemoryVariable, Process +from b_asic.process import ( + MemoryProcess, + MemoryVariable, + OperatorProcess, + PlainMemoryVariable, + Process, +) from b_asic.types import TypeName # Default latency coloring RGB tuple @@ -1272,10 +1278,20 @@ class ProcessCollection: if process.execution_time <= length: short.append(process) else: - long.append(process) - return ProcessCollection( - short, schedule_time=self.schedule_time - ), ProcessCollection(long, schedule_time=self.schedule_time) + if isinstance(process, MemoryProcess): + # Split this MultiReadProcess into two new processes + p_short, p_long = process.split_on_length(length) + if p_short is not None: + short.append(p_short) + if p_long is not None: + long.append(p_long) + else: + # Not a MultiReadProcess: has only a single read + long.append(process) + return ( + ProcessCollection(short, self.schedule_time, self._cyclic), + ProcessCollection(long, self.schedule_time, self._cyclic), + ) def generate_register_based_storage_vhdl( self, diff --git a/test/test_process.py b/test/test_process.py index 213003afd72b920c64344582dc4f3365ad28c1e7..7ed1517957eda987811da881af8c84e2983dfc32 100644 --- a/test/test_process.py +++ b/test/test_process.py @@ -10,8 +10,8 @@ def test_PlainMemoryVariable(): assert mem.write_port == 0 assert mem.start_time == 3 assert mem.execution_time == 2 - assert mem.life_times == (1, 2) - assert mem.read_ports == (4, 5) + assert mem.life_times == [1, 2] + assert mem.read_ports == [4, 5] assert repr(mem) == "PlainMemoryVariable(3, 0, {4: 1, 5: 2}, 'Var. 0')" mem2 = PlainMemoryVariable(2, 0, {4: 2, 5: 3}, 'foo') @@ -39,3 +39,35 @@ def test_MemoryVariables(secondorder_iir_schedule): def test_OperatorProcess_error(secondorder_iir_schedule): with pytest.raises(ValueError, match="does not have an execution time specified"): _ = secondorder_iir_schedule.get_operations() + + +def test_MultiReadProcess(): + mv = PlainMemoryVariable(3, 0, {0: 1, 1: 2, 2: 5}, name="MV") + + with pytest.raises(KeyError, match=r'Process MV: 3 not in life_times: \[1, 2, 5\]'): + mv._remove_life_time(3) + + assert mv.life_times == [1, 2, 5] + assert mv.execution_time == 5 + mv._remove_life_time(5) + assert mv.life_times == [1, 2] + assert mv.execution_time == 2 + mv._add_life_time(4) + assert mv.execution_time == 4 + assert mv.life_times == [1, 2, 4] + mv._add_life_time(4) + assert mv.life_times == [1, 2, 4] + + +def test_split_on_length(): + mv = PlainMemoryVariable(3, 0, {0: 1, 1: 2, 2: 5}, name="MV") + short, long = mv.split_on_length(2) + assert short is not None and long is not None + assert short.start_time == 3 and long.start_time == 3 + assert short.execution_time == 2 and long.execution_time == 5 + assert short.reads == {0: 1, 1: 2} + assert long.reads == {2: 5} + + short, long = mv.split_on_length(0) + assert short is None + assert long is not None diff --git a/test/test_resources.py b/test/test_resources.py index 7b05a7bdd3b0fd54f540cfd3ff47f01302b4651d..b27755f8fa632a4da2ad209dffe57e22bdb31bf1 100644 --- a/test/test_resources.py +++ b/test/test_resources.py @@ -232,3 +232,28 @@ class TestProcessCollectionPlainMemoryVariable: match="MV0 has execution time greater than the schedule time", ): collection.split_on_execution_time(heuristic) + + def test_split_on_length(self): + # Test 1: Exclude a zero-time access time + collection = ProcessCollection( + collection=[PlainMemoryVariable(0, 1, {0: 1, 1: 2, 2: 3})], + schedule_time=4, + ) + short, long = collection.split_on_length(0) + assert len(short) == 0 and len(long) == 1 + for split_time in [1, 2]: + short, long = collection.split_on_length(split_time) + assert len(short) == 1 and len(long) == 1 + short, long = collection.split_on_length(3) + assert len(short) == 1 and len(long) == 0 + + # Test 2: Include a zero-time access time + collection = ProcessCollection( + collection=[PlainMemoryVariable(0, 1, {0: 0, 1: 1, 2: 2, 3: 3})], + schedule_time=4, + ) + short, long = collection.split_on_length(0) + assert len(short) == 1 and len(long) == 1 + for split_time in [1, 2]: + short, long = collection.split_on_length(split_time) + assert len(short) == 1 and len(long) == 1