From 7f64ae990f1b2ed6a1395f449b8f0b11e55f288d Mon Sep 17 00:00:00 2001 From: Robier Al Kaadi <robal695@student.liu.se> Date: Mon, 15 Jul 2024 10:45:04 +0000 Subject: [PATCH] Resolve "Incorrect SFGs do not raise" --- b_asic/signal_flow_graph.py | 27 ++++++++++++++++++++++- test/test_core_operations.py | 8 +++---- test/test_sfg.py | 42 ++++++++++++++++++++++++++++++++---- 3 files changed, 67 insertions(+), 10 deletions(-) diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 44979196..cbd491d7 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -6,6 +6,7 @@ Contains the signal flow graph operation. import itertools import re +import warnings from collections import defaultdict, deque from io import StringIO from numbers import Number @@ -66,10 +67,11 @@ class GraphIDGenerator: def next_id(self, type_name: TypeName, used_ids: MutableSet = set()) -> GraphID: """Get the next graph id for a certain graph id type.""" new_id = type_name + str(self._next_id_number[type_name]) - self._next_id_number[type_name] += 1 + self._next_id_number[type_name] = 0 while new_id in used_ids: self._next_id_number[type_name] += 1 new_id = type_name + str(self._next_id_number[type_name]) + used_ids.add(GraphID(new_id)) return GraphID(new_id) @property @@ -281,12 +283,18 @@ class SFG(AbstractOperation): ) # Search the graph inwards from each output signal. + output_sources = [] for ( signal, output_index, ) in self._original_output_signals_to_indices.items(): # Check if already added source. new_signal = cast(Signal, self._original_components_to_new[signal]) + + if new_signal.source in output_sources: + warnings.warn("Two signals connected to the same output port") + output_sources.append(new_signal.source) + if new_signal.source is None: if signal.source is None: raise ValueError( @@ -295,6 +303,11 @@ class SFG(AbstractOperation): if signal.source.operation not in self._original_components_to_new: self._add_operation_connected_tree_copy(signal.source.operation) + if len(output_sources) != (output_operation_count + output_signal_count): + raise ValueError( + "At least one output operation is not connected!, Tips: Check for output ports that are connected to the same signal" + ) + def __str__(self) -> str: """Return a string representation of this SFG.""" string_io = StringIO() @@ -639,6 +652,12 @@ class SFG(AbstractOperation): signal.remove_source() signal.set_source(component.output(index_out)) + if component_copy.type_name() == 'out': + sfg_copy._output_operations.remove(component_copy) + warnings.warn(f"Output port {component_copy.graph_id} has been removed") + if component.type_name() == 'out': + sfg_copy._output_operations.append(component) + return sfg_copy() # Copy again to update IDs. def insert_operation( @@ -2097,3 +2116,9 @@ class SFG(AbstractOperation): ret = list({op.type_name() for op in self.operations}) ret.sort() return ret + + def get_used_graph_ids(self) -> Set[GraphID]: + """Get a list of all GraphID:s used in the SFG.""" + ret = set({op.graph_id for op in self.operations}) + sorted(ret) + return ret diff --git a/test/test_core_operations.py b/test/test_core_operations.py index f9d13fd6..6ed00f45 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -3,7 +3,6 @@ import pytest from b_asic import ( - SFG, Absolute, Addition, AddSub, @@ -415,9 +414,8 @@ class TestSink: sfg = bfly.to_sfg() s = Sink() sfg1 = sfg.replace_operation(s, "out0") - sfg2 = SFG(sfg1.input_operations, sfg1.output_operations[1:]) - assert sfg2.output_count == 1 - assert sfg2.input_count == 2 + assert sfg1.output_count == 1 + assert sfg1.input_count == 2 - assert sfg.evaluate_output(1, [0, 1]) == sfg2.evaluate_output(0, [0, 1]) + assert sfg.evaluate_output(1, [0, 1]) == sfg1.evaluate_output(0, [0, 1]) diff --git a/test/test_sfg.py b/test/test_sfg.py index eae9c403..2c1b8f11 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1351,8 +1351,10 @@ class TestSFGErrors: adaptor = SymmetricTwoportAdaptor(0.5, in1, in2) out1 = Output(adaptor.output(0)) out2 = Output() - # No error, should be - SFG([in1, in2], [out1, out2]) + with pytest.raises( + ValueError, match="At least one output operation is not connected!, Tips: Check for output ports that are connected to the same signal" + ): + SFG([in1, in2], [out1, out2]) def test_unconnected_input(self): in1 = Input() @@ -1421,8 +1423,10 @@ class TestSFGErrors: adaptor = SymmetricTwoportAdaptor(0.5, in1, in2) out1 = Output(adaptor.output(0)) signal = Signal(adaptor.output(1)) - # Should raise? - SFG([in1, in2], [out1], output_signals=[signal, signal]) + with pytest.raises( + ValueError, match="At least one output operation is not connected!, Tips: Check for output ports that are connected to the same signal" + ): + SFG([in1, in2], [out1], output_signals=[signal, signal]) def test_dangling_input_signal(self): in1 = Input() @@ -1745,6 +1749,36 @@ class TestGetUsedTypeNames: assert sfg.get_used_type_names() == ['add', 'c', 'out'] +class Test_Keep_GraphIDs: + def test_single_accumulator(self): + + i = Input() + d = Delay() + o = Output(d) + c = ConstantMultiplication(0.5, d) + a = Addition(i, c) + d.input(0).connect(a) + + sfg = SFG([i], [o]) + sfg = sfg.insert_operation_before('t0', ConstantMultiplication(8)) + sfg = sfg.insert_operation_after('t0', ConstantMultiplication(8)) + sfg = sfg.insert_operation(ConstantMultiplication(8), 't0') + assert sfg.get_used_graph_ids() == { + 'add0', + 'cmul0', + 'cmul1', + 'cmul2', + 'cmul3', + 'in0', + 'out0', + 't0', + } + + def test_large_operation_tree(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + assert sfg.get_used_type_names() == ['add', 'c', 'out'] + + class TestInsertDelays: def test_insert_delays_before_operation(self): in1 = Input() -- GitLab