From 6717f35553edd91960252d1bd3c15b5a2df7abfe Mon Sep 17 00:00:00 2001 From: Mikael Henriksson <mike.zx@hotmail.com> Date: Mon, 15 May 2023 19:25:38 +0200 Subject: [PATCH] architecture.py: add support for moving Processes between Resources --- b_asic/architecture.py | 101 ++++++++++++++++++++++++++++++++++++-- b_asic/resources.py | 18 +++++++ test/test_architecture.py | 88 +++++++++++++++++++++++++++++++-- test/test_resources.py | 7 +++ 4 files changed, 206 insertions(+), 8 deletions(-) diff --git a/b_asic/architecture.py b/b_asic/architecture.py index 480da9c4..c60b2382 100644 --- a/b_asic/architecture.py +++ b/b_asic/architecture.py @@ -3,6 +3,7 @@ B-ASIC architecture classes. """ from collections import defaultdict from io import TextIOWrapper +from itertools import chain from typing import ( DefaultDict, Dict, @@ -12,6 +13,7 @@ from typing import ( Optional, Set, Tuple, + Type, Union, cast, ) @@ -20,8 +22,9 @@ import matplotlib.pyplot as plt from graphviz import Digraph from b_asic.codegen.vhdl.common import is_valid_vhdl_identifier +from b_asic.operation import Operation from b_asic.port import InputPort, OutputPort -from b_asic.process import MemoryVariable, OperatorProcess, PlainMemoryVariable +from b_asic.process import MemoryProcess, MemoryVariable, OperatorProcess, Process from b_asic.resources import ProcessCollection @@ -249,6 +252,36 @@ class Resource(HardwareBlock): def collection(self) -> ProcessCollection: return self._collection + @property + def operation_type(self) -> Union[Type[MemoryProcess], Type[OperatorProcess]]: + raise NotImplementedError("ABC Resource does not implement operation_type") + + def add_process(self, proc: Process): + """ + Add a :class:`~b_asic.process.Process` to this :class:`Resource`. + + Raises :class:`KeyError` if the process being added is not of the same type + as the other processes. + + Parameters + ---------- + proc : :class:`~b_asic.process.Process` + The process to add. + """ + if isinstance(proc, OperatorProcess): + # operation_type marks OperatorProcess associated operation. + if not isinstance(proc._operation, self.operation_type): + raise KeyError(f"{proc} not of type {self.operation_type}") + else: + # operation_type is MemoryVariable or PlainMemoryVariable + if not isinstance(proc, self.operation_type): + raise KeyError(f"{proc} not of type {self.operation_type}") + self.collection.add_process(proc) + + def remove_process(self, proc): + self.collection.remove_process(proc) + self._assignment = None + class ProcessingElement(Resource): """ @@ -310,6 +343,10 @@ class ProcessingElement(Resource): self._assignment = None raise ValueError("Cannot map ProcessCollection to single ProcessingElement") + @property + def operation_type(self) -> Type[Operation]: + return self._operation_type + class Memory(Resource): """ @@ -339,12 +376,11 @@ class Memory(Resource): ): super().__init__(process_collection=process_collection, entity_name=entity_name) if not all( - isinstance(operator, (MemoryVariable, PlainMemoryVariable)) + isinstance(operator, MemoryProcess) for operator in process_collection.collection ): raise TypeError( - "Can only have MemoryVariable or PlainMemoryVariable in" - " ProcessCollection when creating Memory" + "Can only have MemoryProcess in ProcessCollection when creating Memory" ) if memory_type not in ("RAM", "register"): raise ValueError( @@ -366,6 +402,14 @@ class Memory(Resource): self._input_count = write_ports self._memory_type = memory_type + memory_processes = [ + cast(MemoryProcess, process) for process in process_collection + ] + mem_proc_type = type(memory_processes[0]) + if not all(isinstance(proc, mem_proc_type) for proc in memory_processes): + raise TypeError("Different MemoryProcess types in ProcessCollection") + self._operation_type = mem_proc_type + def __iter__(self) -> Iterator[MemoryVariable]: # Add information about the iterator type return cast(Iterator[MemoryVariable], iter(self._collection)) @@ -393,6 +437,10 @@ class Memory(Resource): else: # "register" raise NotImplementedError() + @property + def operation_type(self) -> Type[MemoryProcess]: + return self._operation_type + class Architecture(HardwareBlock): """ @@ -577,6 +625,51 @@ of :class:`~b_asic.architecture.ProcessingElement` d_out[i][self._variable_outport_to_resource[output]] += 1 return [dict(d) for d in d_in], [dict(d) for d in d_out] + def resource_from_name(self, name: str): + re = {p.entity_name: p for p in chain(self.memories, self.processing_elements)} + return re[name] + + def move_process( + self, + proc: Union[str, Process], + re_from: Union[str, Resource], + re_to: Union[str, Resource], + ): + """ + Move a :class:`b_asic.process.Process` from one resource to another in the + architecture. + + Both the resource moved from and will become unassigned after a process has been + moved. + + Raises :class:`KeyError` if ``proc`` is not present in resource ``re_from``. + + Parameters + ---------- + proc : :class:`b_asic.process.Process` or string + The process (or its given name) to move. + re_from : :class:`b_asic.architecture.Resource` or string + The resource (or its given name) to move the process from. + re_to : :class:`b_asic.architecture.Resource` or string + The resource (or its given name) to move the process to. + """ + # Extract resouces from name + if isinstance(re_from, str): + re_from = self.resource_from_name(re_from) + if isinstance(re_to, str): + re_to = self.resource_from_name(re_to) + + # Extract process from name + if isinstance(proc, str): + proc = re_from.collection.from_name(proc) + + # Move the process. + if proc not in re_from.collection: + raise KeyError(f"{proc} not in {re_from}") + else: + re_to.add_process(proc) + re_from.remove_process(proc) + def _digraph(self) -> Digraph: edges: Set[Tuple[str, str, str]] = set() dg = Digraph(node_attr={'shape': 'record'}) diff --git a/b_asic/resources.py b/b_asic/resources.py index 487d172b..d655175c 100644 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -1427,3 +1427,21 @@ class ProcessCollection: writes = [process.start_time for process in self._collection] count = Counter(writes) return max(count.values()) + + def from_name(self, name: str): + """ + Get a :class:`~b_asic.process.Process` from this collection from its name. + + Raises :class:`KeyError` if no processes with ``name`` is found in this + colleciton. + + Parameters + ---------- + name : str + The name of the process to retrieve. + """ + name_to_proc = {p.name: p for p in self.collection} + if name not in name_to_proc: + raise KeyError(f'{name} not in {self}') + else: + return name_to_proc[name] diff --git a/test/test_architecture.py b/test/test_architecture.py index 8b52b070..48e4cc31 100644 --- a/test/test_architecture.py +++ b/test/test_architecture.py @@ -5,7 +5,7 @@ import pytest from b_asic.architecture import Architecture, Memory, ProcessingElement from b_asic.core_operations import Addition, ConstantMultiplication -from b_asic.process import MemoryVariable, OperatorProcess +from b_asic.process import MemoryVariable, OperatorProcess, PlainMemoryVariable from b_asic.resources import ProcessCollection from b_asic.schedule import Schedule from b_asic.special_operations import Input, Output @@ -25,6 +25,24 @@ def test_processing_element_exceptions(schedule_direct_form_iir_lp_filter: Sched ProcessingElement(empty_collection) +def test_add_remove_process_from_resource(schedule_direct_form_iir_lp_filter: Schedule): + mvs = schedule_direct_form_iir_lp_filter.get_memory_variables() + operations = schedule_direct_form_iir_lp_filter.get_operations() + memory = Memory(mvs) + pe = ProcessingElement( + operations.get_by_type_name(ConstantMultiplication.type_name()) + ) + for process in operations: + with pytest.raises(KeyError, match=f"{process} not of type"): + memory.add_process(process) + for process in mvs: + with pytest.raises(KeyError, match=f"{process} not of type"): + pe.add_process(process) + + with pytest.raises(KeyError, match="PlainMV not of type"): + memory.add_process(PlainMemoryVariable(0, 0, {0: 2}, "PlainMV")) + + def test_extract_processing_elements(schedule_direct_form_iir_lp_filter: Schedule): # Extract operations from schedule operations = schedule_direct_form_iir_lp_filter.get_operations() @@ -53,9 +71,7 @@ def test_memory_exceptions(schedule_direct_form_iir_lp_filter: Schedule): ValueError, match="Do not create Resource with empty ProcessCollection" ): Memory(empty_collection) - with pytest.raises( - TypeError, match="Can only have MemoryVariable or PlainMemoryVariable" - ): + with pytest.raises(TypeError, match="Can only have MemoryProcess"): Memory(operations) # No exception Memory(mvs) @@ -137,3 +153,67 @@ def test_architecture(schedule_direct_form_iir_lp_filter: Schedule): mv = cast(MemoryVariable, mv) print(f' {mv.start_time} -> {mv.execution_time}: {mv.write_port.name}') print(architecture.get_interconnects_for_memory(memory)) + + +def test_move_process(schedule_direct_form_iir_lp_filter: Schedule): + # Resources + mvs = schedule_direct_form_iir_lp_filter.get_memory_variables() + operations = schedule_direct_form_iir_lp_filter.get_operations() + adders1, adders2 = operations.get_by_type_name(Addition.type_name()).split_on_ports( + total_ports=1 + ) + adders1 = [adders1] # Fake two PEs needed for the adders + adders2 = [adders2] # Fake two PEs needed for the adders + const_mults = operations.get_by_type_name( + ConstantMultiplication.type_name() + ).split_on_execution_time() + inputs = operations.get_by_type_name(Input.type_name()).split_on_execution_time() + outputs = operations.get_by_type_name(Output.type_name()).split_on_execution_time() + + # Create necessary processing elements + processing_elements: List[ProcessingElement] = [ + ProcessingElement(operation, entity_name=f'pe{i}') + for i, operation in enumerate(chain(adders1, adders2, const_mults)) + ] + for i, pc in enumerate(inputs): + processing_elements.append(ProcessingElement(pc, entity_name=f'input{i}')) + for i, pc in enumerate(outputs): + processing_elements.append(ProcessingElement(pc, entity_name=f'output{i}')) + + # Extract zero-length memory variables + direct_conn, mvs = mvs.split_on_length() + + # Create Memories from the memory variables (split on length to get two memories) + memories: List[Memory] = [Memory(pc) for pc in mvs.split_on_length(6)] + + # Create architecture + architecture = Architecture( + processing_elements, memories, direct_interconnects=direct_conn + ) + + # Some movement that must work + assert memories[1].collection.from_name('cmul4.0') + architecture.move_process('cmul4.0', memories[1], memories[0]) + assert memories[0].collection.from_name('cmul4.0') + + assert memories[1].collection.from_name('in1.0') + architecture.move_process('in1.0', memories[1], memories[0]) + assert memories[0].collection.from_name('in1.0') + + assert processing_elements[1].collection.from_name('add1') + architecture.move_process('add1', processing_elements[1], processing_elements[0]) + assert processing_elements[0].collection.from_name('add1') + + # Processes leave the resources they have moved from + with pytest.raises(KeyError): + memories[1].collection.from_name('cmul4.0') + with pytest.raises(KeyError): + memories[1].collection.from_name('in1.0') + with pytest.raises(KeyError): + processing_elements[1].collection.from_name('add1') + + # Processes can only be moved when the source and destination process-types match + with pytest.raises(KeyError, match="cmul4.0 not of type"): + architecture.move_process('cmul4.0', memories[0], processing_elements[0]) + with pytest.raises(KeyError, match="invalid_name not in"): + architecture.move_process('invalid_name', memories[0], processing_elements[1]) diff --git a/test/test_resources.py b/test/test_resources.py index b27755f8..df034037 100644 --- a/test/test_resources.py +++ b/test/test_resources.py @@ -257,3 +257,10 @@ class TestProcessCollectionPlainMemoryVariable: for split_time in [1, 2]: short, long = collection.split_on_length(split_time) assert len(short) == 1 and len(long) == 1 + + def test_from_name(self): + a = PlainMemoryVariable(0, 0, {0: 2}, name="cool name 1337") + collection = ProcessCollection([a], schedule_time=5, cyclic=True) + with pytest.raises(KeyError, match="epic_name not in ..."): + collection.from_name("epic_name") + assert a == collection.from_name("cool name 1337") -- GitLab