diff --git a/b_asic/operation.py b/b_asic/operation.py index 21e7012eaf7a333c5db8a7f8a6c741b3220030b8..02ba1aa50682448e931a0694d24e03e20eadd399 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -193,6 +193,7 @@ class Operation(GraphComponent, SignalSourceProvider): """ raise NotImplementedError + class AbstractOperation(Operation, AbstractGraphComponent): """Generic abstract operation class which most implementations will derive from. TODO: More info. diff --git a/b_asic/port.py b/b_asic/port.py index 59a218d9f8aa288d0aacb9dea15ca2cf0a604355..20783d5df0962b034aee2b6e934255a9fc9cd6e6 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -128,7 +128,7 @@ class InputPort(AbstractPort): signal.set_destination(self) def remove_signal(self, signal: Signal) -> None: - assert signal is self._source_signal, "Attempted to remove already removed signal." + assert signal is self._source_signal, "Attempted to remove signal that is not connected." self._source_signal = None signal.remove_destination() @@ -177,7 +177,7 @@ class OutputPort(AbstractPort, SignalSourceProvider): signal.set_source(self) def remove_signal(self, signal: Signal) -> None: - assert signal in self._destination_signals, "Attempted to remove already removed signal." + assert signal in self._destination_signals, "Attempted to remove signal that is not connected." self._destination_signals.remove(signal) signal.remove_source() diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 10a383bd7c0e8f300ef0ac7f045fd500e5e7af92..79b539cc4f037227a2ec3debe860618538c1ccb8 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -258,7 +258,7 @@ class SFG(AbstractOperation): def split(self) -> Iterable[Operation]: return self.operations - + def to_sfg(self) -> 'SFG': return self @@ -480,7 +480,7 @@ class SFG(AbstractOperation): # The old SFG will be deleted by Python GC return _sfg_copy() - def insert_operation(self, component: Operation, output_comp_id: GraphID): + def insert_operation(self, component: Operation, output_comp_id: GraphID) -> Optional["SFG"]: """Insert an operation in the SFG after a given source operation. The source operation output count must match the input count of the operation as well as the output Then return a new deepcopy of the sfg with the inserted component. @@ -511,6 +511,37 @@ class SFG(AbstractOperation): # Recreate the newly coupled SFG so that all attributes are correct. return sfg_copy() + def remove_operation(self, operation_id: GraphID) -> "SFG": + """Returns a version of the SFG where the operation with the specified GraphID removed. + The operation has to have the same amount of input- and output ports or a ValueError will + be raised. If no operation with the entered operation_id is found then returns None and does nothing.""" + sfg_copy = self() + operation = sfg_copy.find_by_id(operation_id) + if operation is None: + return None + + if operation.input_count != operation.output_count: + raise ValueError("Different number of input and output ports of operation with the specified id") + + for i, outport in enumerate(operation.outputs): + if outport.signal_count > 0: + if operation.input(i).signal_count > 0 and operation.input(i).signals[0].source is not None: + in_sig = operation.input(i).signals[0] + source_port = in_sig.source + source_port.remove_signal(in_sig) + operation.input(i).remove_signal(in_sig) + for out_sig in outport.signals.copy(): + out_sig.set_source(source_port) + else: + for out_sig in outport.signals.copy(): + out_sig.remove_source() + else: + if operation.input(i).signal_count > 0: + in_sig = operation.input(i).signals[0] + operation.input(i).remove_signal(in_sig) + + return sfg_copy() + def _evaluate_source(self, src: OutputPort, results: MutableOutputMap, registers: MutableRegisterMap, prefix: str) -> Number: src_prefix = prefix if src_prefix: diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index e2145b0a2a5974222c8c3d740cb3f53d76c7e445..08d9e8aa2bacd0b1c1a11c17c174179d853e6ed7 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -113,8 +113,9 @@ def simple_filter(): in1 = Input("IN1") constmul1 = ConstantMultiplication(0.5, name="CMUL1") add1 = Addition(in1, constmul1, "ADD1") + add1.input(1).signals[0].name = "S2" reg = Register(add1, name="REG1") - constmul1.input(0).connect(reg) + constmul1.input(0).connect(reg, "S1") out1 = Output(reg, "OUT1") return SFG(inputs=[in1], outputs=[out1], name="simple_filter") diff --git a/test/test_sfg.py b/test/test_sfg.py index b6625766b2f17ea142a875ffa6f767786cc4ad94..6f0a7ec40ef0042f7a0a5867467d662a604d5b43 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -707,3 +707,72 @@ class TestTopologicalOrderOperations: topological_order = sfg_two_inputs_two_outputs_independent.get_operations_topological_order() assert [comp.name for comp in topological_order] == ["IN1", "OUT1", "IN2", "C1", "ADD1", "OUT2"] + + +class TestRemove: + def test_remove_single_input_outputs(self, simple_filter): + new_sfg = simple_filter.remove_operation("cmul1") + + assert set(op.name for op in simple_filter.find_by_name("REG1")[0].subsequent_operations) == {"CMUL1", "OUT1"} + assert set(op.name for op in new_sfg.find_by_name("REG1")[0].subsequent_operations) == {"ADD1", "OUT1"} + + assert set(op.name for op in simple_filter.find_by_name("ADD1")[0].preceding_operations) == {"CMUL1", "IN1"} + assert set(op.name for op in new_sfg.find_by_name("ADD1")[0].preceding_operations) == {"REG1", "IN1"} + + assert "S1" in set([sig.name for sig in simple_filter.find_by_name("REG1")[0].output(0).signals]) + assert "S2" in set([sig.name for sig in new_sfg.find_by_name("REG1")[0].output(0).signals]) + + def test_remove_multiple_inputs_outputs(self, butterfly_operation_tree): + out1 = Output(butterfly_operation_tree.output(0), "OUT1") + out2 = Output(butterfly_operation_tree.output(1), "OUT2") + + sfg = SFG(outputs=[out1, out2]) + + new_sfg = sfg.remove_operation(sfg.find_by_name("bfly2")[0].graph_id) + + assert sfg.find_by_name("bfly3")[0].output(0).signal_count == 1 + assert new_sfg.find_by_name("bfly3")[0].output(0).signal_count == 1 + + sfg_dest_0 = sfg.find_by_name("bfly3")[0].output(0).signals[0].destination + new_sfg_dest_0 = new_sfg.find_by_name("bfly3")[0].output(0).signals[0].destination + + assert sfg_dest_0.index == 0 + assert new_sfg_dest_0.index == 0 + assert sfg_dest_0.operation.name == "bfly2" + assert new_sfg_dest_0.operation.name == "bfly1" + + assert sfg.find_by_name("bfly3")[0].output(1).signal_count == 1 + assert new_sfg.find_by_name("bfly3")[0].output(1).signal_count == 1 + + sfg_dest_1 = sfg.find_by_name("bfly3")[0].output(1).signals[0].destination + new_sfg_dest_1 = new_sfg.find_by_name("bfly3")[0].output(1).signals[0].destination + + assert sfg_dest_1.index == 1 + assert new_sfg_dest_1.index == 1 + assert sfg_dest_1.operation.name == "bfly2" + assert new_sfg_dest_1.operation.name == "bfly1" + + assert sfg.find_by_name("bfly1")[0].input(0).signal_count == 1 + assert new_sfg.find_by_name("bfly1")[0].input(0).signal_count == 1 + + sfg_source_0 = sfg.find_by_name("bfly1")[0].input(0).signals[0].source + new_sfg_source_0 = new_sfg.find_by_name("bfly1")[0].input(0).signals[0].source + + assert sfg_source_0.index == 0 + assert new_sfg_source_0.index == 0 + assert sfg_source_0.operation.name == "bfly2" + assert new_sfg_source_0.operation.name == "bfly3" + + sfg_source_1 = sfg.find_by_name("bfly1")[0].input(1).signals[0].source + new_sfg_source_1 = new_sfg.find_by_name("bfly1")[0].input(1).signals[0].source + + assert sfg_source_1.index == 1 + assert new_sfg_source_1.index == 1 + assert sfg_source_1.operation.name == "bfly2" + assert new_sfg_source_1.operation.name == "bfly3" + + assert "bfly2" not in set(op.name for op in new_sfg.operations) + + def remove_different_number_inputs_outputs(self, simple_filter): + with pytest.raises(ValueError): + simple_filter.remove_operation("add1")