diff --git a/b_asic/schedule.py b/b_asic/schedule.py index ef0a677274edd39ab8698ad5a4fa316865707448..99ed184cd4a127c97e828ef8d816ea12bf5e1626 100644 --- a/b_asic/schedule.py +++ b/b_asic/schedule.py @@ -764,9 +764,13 @@ class Schedule: Reintroduce delay elements to each signal according to the ``_laps`` variable. """ new_sfg = self._sfg() + destination_laps = [] for signal_id,lap in self._laps.items(): + port = new_sfg.find_by_id(signal_id).destination + destination_laps.append((port.operation.graph_id, port.index, lap)) + for op,port,lap in destination_laps: for delays in range(lap): - new_sfg = new_sfg.insert_operation_after(signal_id, Delay()) + new_sfg = new_sfg.insert_operation_before(op, Delay(), port) return new_sfg() def _schedule_alap(self) -> None: diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index c2e3b43eba25056a2ad855e9d99f9a8d7dad4ee6..befbd67014d22f2790fbdc726250b6d7a11c0bf8 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -736,6 +736,56 @@ class SFG(AbstractOperation): # Recreate the newly coupled SFG so that all attributes are correct. return sfg_copy() + def insert_operation_before( + self, + input_comp_id: GraphID, + new_operation: Operation, + port: Optional[int] = None + ) -> Optional["SFG"]: + """ + Insert an operation in the SFG before a given source operation. + + Then return a new deepcopy of the sfg with the inserted component. + + The graph_id can be an Operation or a Signal. If the operation has multiple + inputs, (copies of) the same operation will be inserted on every port. + To specify a port use the ``port`` parameter. + + Currently, the new operation must have one input and one output. + + Parameters + ---------- + input_comp_id : GraphID + The source operation GraphID to connect to. + new_operation : Operation + The new operation, e.g. Multiplication. + port : Optional[int] + The number of the InputPort before which the new operation shall be inserted + """ + + # Preserve the original SFG by creating a copy. + sfg_copy = self() + if new_operation.output_count != 1 or new_operation.input_count != 1: + raise TypeError( + "Only operations with one input and one output can be inserted." + ) + + input_comp = sfg_copy.find_by_id(input_comp_id) + if input_comp is None: + raise ValueError(f"Unknown component: {input_comp_id!r}") + if isinstance(input_comp, Operation): + if port is None: + sfg_copy._insert_operation_before_operation(input_comp, new_operation) + else: + sfg_copy._insert_operation_before_inputport( + input_comp.input(port), new_operation + ) + elif isinstance(input_comp, Signal): + sfg_copy._insert_operation_after_signal(input_comp, new_operation) + + # Recreate the newly coupled SFG so that all attributes are correct. + return sfg_copy() + def simplify_delay_element_placement(self) -> "SFG": """ Simplify an SFG by removing some redundant delay elements. @@ -746,24 +796,30 @@ class SFG(AbstractOperation): """ sfg_copy = self() - for delay_element in sfg_copy.find_by_type_name(Delay.type_name()): - neighboring_delays = [] - if len(delay_element.inputs[0].signals) > 0: - for signal in delay_element.inputs[0].signals[0].source.signals: - if isinstance(signal.destination.operation, Delay): - neighboring_delays.append(signal.destination.operation) - - if delay_element in neighboring_delays: - neighboring_delays.remove(delay_element) - - for delay in neighboring_delays: - for output in delay.outputs[0].signals: - output.set_source(delay_element.outputs[0]) - in_sig = delay.input(0).signals[0] - delay.input(0).remove_signal(in_sig) - in_sig.source.remove_signal(in_sig) - - return sfg_copy() + no_of_delays = len(sfg_copy.find_by_type_name(Delay.type_name())) + while True: + for delay_element in sfg_copy.find_by_type_name(Delay.type_name()): + neighboring_delays = [] + if len(delay_element.inputs[0].signals) > 0: + for signal in delay_element.inputs[0].signals[0].source.signals: + if isinstance(signal.destination.operation, Delay): + neighboring_delays.append(signal.destination.operation) + + if delay_element in neighboring_delays: + neighboring_delays.remove(delay_element) + + for delay in neighboring_delays: + for output in delay.outputs[0].signals: + output.set_source(delay_element.outputs[0]) + in_sig = delay.input(0).signals[0] + delay.input(0).remove_signal(in_sig) + in_sig.source.remove_signal(in_sig) + sfg_copy = sfg_copy() + if no_of_delays <= len(sfg_copy.find_by_type_name(Delay.type_name())): + break + no_of_delays = len(sfg_copy.find_by_type_name(Delay.type_name())) + + return sfg_copy def _insert_operation_after_operation( self, output_operation: Operation, new_operation: Operation @@ -771,6 +827,12 @@ class SFG(AbstractOperation): for output in output_operation.outputs: self._insert_operation_after_outputport(output, new_operation.copy()) + def _insert_operation_before_operation( + self, input_operation: Operation, new_operation: Operation + ): + for port in input_operation.inputs: + self._insert_operation_before_inputport(port, new_operation.copy()) + def _insert_operation_after_outputport( self, output_port: OutputPort, new_operation: Operation ): @@ -780,12 +842,25 @@ class SFG(AbstractOperation): signal.set_source(new_operation) new_operation.input(0).connect(output_port) + def _insert_operation_before_inputport( + self, input_port: InputPort, new_operation: Operation + ): + # Make copy as list will be updated + input_port.signals[0].set_destination(new_operation) + new_operation.output(0).add_signal(Signal(destination=input_port)) + def _insert_operation_before_signal(self, signal: Signal, new_operation: Operation): output_port = signal.source output_port.remove_signal(signal) Signal(output_port, new_operation) signal.set_source(new_operation) + def _insert_operation_after_signal(self, signal: Signal, new_operation: Operation): + input_port = signal.destination + input_port.remove_signal(signal) + Signal(new_operation, input_port) + signal.set_destination(new_operation) + def swap_io_of_operation(self, operation_id: GraphID) -> None: """ Swap the inputs (and outputs) of operation. diff --git a/test/test_schedule.py b/test/test_schedule.py index aba52d2f1d39e52a490fb6016c8d752b2b6aad30..d9e4fe175818b89a803dd003acf4ffd480f3ebc7 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -11,6 +11,7 @@ from b_asic.process import OperatorProcess from b_asic.schedule import Schedule from b_asic.signal_flow_graph import SFG from b_asic.special_operations import Delay, Input, Output +from b_asic.sfg_generators import direct_form_fir class TestInit: @@ -534,9 +535,20 @@ class TestRescheduling: sfg = schedule.sfg assert sfg_direct_form_iir_lp_filter.evaluate(5) == sfg.evaluate(5) - - - + fir_sfg = direct_form_fir( + list(range(1, 10)), + mult_properties={ + 'latency': 2, + 'execution_time': 1 + }, + add_properties={ + 'latency': 2, + 'execution_time': 1 + } + ) + schedule = Schedule(fir_sfg, algorithm="ASAP") + sfg = schedule.sfg + assert fir_sfg.evaluate(5) == sfg.evaluate(5) class TestTimeResolution: def test_increase_time_resolution( diff --git a/test/test_sfg.py b/test/test_sfg.py index cf97709c981de3e1e5838814328f1aaaf462efb4..1bd6d9fddab6f66a721182d89fad642546616342 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1630,6 +1630,57 @@ class TestInsertComponentAfter: with pytest.raises(ValueError, match="Unknown component:"): sfg.insert_operation_after('foo', SquareRoot()) +class TestInsertComponentBefore: + def test_insert_component_before_in_sfg(self, butterfly_operation_tree): + sfg = SFG(outputs=list(map(Output, butterfly_operation_tree.outputs))) + sqrt = SquareRoot() + + _sfg = sfg.insert_operation_before( + sfg.find_by_name("bfly1")[0].graph_id, sqrt, port=0 + ) + assert _sfg.evaluate() != sfg.evaluate() + + assert any([isinstance(comp, SquareRoot) for comp in _sfg.operations]) + assert not any([isinstance(comp, SquareRoot) for comp in sfg.operations]) + + assert not isinstance( + sfg.find_by_name("bfly1")[0].input(0).signals[0].source.operation, + SquareRoot, + ) + assert isinstance( + _sfg.find_by_name("bfly1")[0] + .input(0) + .signals[0] + .source.operation, + SquareRoot, + ) + + assert sfg.find_by_name("bfly1")[0].input(0).signals[ + 0 + ].source.operation is sfg.find_by_name("bfly2")[0] + assert _sfg.find_by_name("bfly1")[0].input(0).signals[ + 0 + ].destination.operation is not _sfg.find_by_name("bfly2")[0] + assert _sfg.find_by_id("sqrt0").input(0).signals[ + 0 + ].source.operation is _sfg.find_by_name("bfly2")[0] + + def test_insert_component_before_mimo_operation_error( + self, large_operation_tree_names + ): + sfg = SFG(outputs=[Output(large_operation_tree_names)]) + with pytest.raises( + TypeError, match="Only operations with one input and one output" + ): + sfg.insert_operation_before('add0', SymmetricTwoportAdaptor(0.5), port=0) + + def test_insert_component_before_unknown_component_error( + self, large_operation_tree_names + ): + sfg = SFG(outputs=[Output(large_operation_tree_names)]) + with pytest.raises(ValueError, match="Unknown component:"): + sfg.insert_operation_before('foo', SquareRoot()) + class TestGetUsedTypeNames: def test_single_accumulator(self, sfg_simple_accumulator: SFG):