From 7f64ae990f1b2ed6a1395f449b8f0b11e55f288d Mon Sep 17 00:00:00 2001
From: Robier Al Kaadi <robal695@student.liu.se>
Date: Mon, 15 Jul 2024 10:45:04 +0000
Subject: [PATCH] Resolve "Incorrect SFGs do not raise"

---
 b_asic/signal_flow_graph.py  | 27 ++++++++++++++++++++++-
 test/test_core_operations.py |  8 +++----
 test/test_sfg.py             | 42 ++++++++++++++++++++++++++++++++----
 3 files changed, 67 insertions(+), 10 deletions(-)

diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py
index 44979196..cbd491d7 100644
--- a/b_asic/signal_flow_graph.py
+++ b/b_asic/signal_flow_graph.py
@@ -6,6 +6,7 @@ Contains the signal flow graph operation.
 
 import itertools
 import re
+import warnings
 from collections import defaultdict, deque
 from io import StringIO
 from numbers import Number
@@ -66,10 +67,11 @@ class GraphIDGenerator:
     def next_id(self, type_name: TypeName, used_ids: MutableSet = set()) -> GraphID:
         """Get the next graph id for a certain graph id type."""
         new_id = type_name + str(self._next_id_number[type_name])
-        self._next_id_number[type_name] += 1
+        self._next_id_number[type_name] = 0
         while new_id in used_ids:
             self._next_id_number[type_name] += 1
             new_id = type_name + str(self._next_id_number[type_name])
+        used_ids.add(GraphID(new_id))
         return GraphID(new_id)
 
     @property
@@ -281,12 +283,18 @@ class SFG(AbstractOperation):
                 )
 
         # Search the graph inwards from each output signal.
+        output_sources = []
         for (
             signal,
             output_index,
         ) in self._original_output_signals_to_indices.items():
             # Check if already added source.
             new_signal = cast(Signal, self._original_components_to_new[signal])
+
+            if new_signal.source in output_sources:
+                warnings.warn("Two signals connected to the same output port")
+            output_sources.append(new_signal.source)
+
             if new_signal.source is None:
                 if signal.source is None:
                     raise ValueError(
@@ -295,6 +303,11 @@ class SFG(AbstractOperation):
                 if signal.source.operation not in self._original_components_to_new:
                     self._add_operation_connected_tree_copy(signal.source.operation)
 
+        if len(output_sources) != (output_operation_count + output_signal_count):
+            raise ValueError(
+                "At least one output operation is not connected!, Tips: Check for output ports that are connected to the same signal"
+            )
+
     def __str__(self) -> str:
         """Return a string representation of this SFG."""
         string_io = StringIO()
@@ -639,6 +652,12 @@ class SFG(AbstractOperation):
                 signal.remove_source()
                 signal.set_source(component.output(index_out))
 
+        if component_copy.type_name() == 'out':
+            sfg_copy._output_operations.remove(component_copy)
+            warnings.warn(f"Output port {component_copy.graph_id} has been removed")
+        if component.type_name() == 'out':
+            sfg_copy._output_operations.append(component)
+
         return sfg_copy()  # Copy again to update IDs.
 
     def insert_operation(
@@ -2097,3 +2116,9 @@ class SFG(AbstractOperation):
         ret = list({op.type_name() for op in self.operations})
         ret.sort()
         return ret
+
+    def get_used_graph_ids(self) -> Set[GraphID]:
+        """Get a list of all GraphID:s used in the SFG."""
+        ret = set({op.graph_id for op in self.operations})
+        sorted(ret)
+        return ret
diff --git a/test/test_core_operations.py b/test/test_core_operations.py
index f9d13fd6..6ed00f45 100644
--- a/test/test_core_operations.py
+++ b/test/test_core_operations.py
@@ -3,7 +3,6 @@
 import pytest
 
 from b_asic import (
-    SFG,
     Absolute,
     Addition,
     AddSub,
@@ -415,9 +414,8 @@ class TestSink:
         sfg = bfly.to_sfg()
         s = Sink()
         sfg1 = sfg.replace_operation(s, "out0")
-        sfg2 = SFG(sfg1.input_operations, sfg1.output_operations[1:])
 
-        assert sfg2.output_count == 1
-        assert sfg2.input_count == 2
+        assert sfg1.output_count == 1
+        assert sfg1.input_count == 2
 
-        assert sfg.evaluate_output(1, [0, 1]) == sfg2.evaluate_output(0, [0, 1])
+        assert sfg.evaluate_output(1, [0, 1]) == sfg1.evaluate_output(0, [0, 1])
diff --git a/test/test_sfg.py b/test/test_sfg.py
index eae9c403..2c1b8f11 100644
--- a/test/test_sfg.py
+++ b/test/test_sfg.py
@@ -1351,8 +1351,10 @@ class TestSFGErrors:
         adaptor = SymmetricTwoportAdaptor(0.5, in1, in2)
         out1 = Output(adaptor.output(0))
         out2 = Output()
-        # No error, should be
-        SFG([in1, in2], [out1, out2])
+        with pytest.raises(
+            ValueError, match="At least one output operation is not connected!, Tips: Check for output ports that are connected to the same signal"
+        ):
+            SFG([in1, in2], [out1, out2])
 
     def test_unconnected_input(self):
         in1 = Input()
@@ -1421,8 +1423,10 @@ class TestSFGErrors:
         adaptor = SymmetricTwoportAdaptor(0.5, in1, in2)
         out1 = Output(adaptor.output(0))
         signal = Signal(adaptor.output(1))
-        # Should raise?
-        SFG([in1, in2], [out1], output_signals=[signal, signal])
+        with pytest.raises(
+            ValueError, match="At least one output operation is not connected!, Tips: Check for output ports that are connected to the same signal"
+        ):
+            SFG([in1, in2], [out1], output_signals=[signal, signal])
 
     def test_dangling_input_signal(self):
         in1 = Input()
@@ -1745,6 +1749,36 @@ class TestGetUsedTypeNames:
         assert sfg.get_used_type_names() == ['add', 'c', 'out']
 
 
+class Test_Keep_GraphIDs:
+    def test_single_accumulator(self):
+
+        i = Input()
+        d = Delay()
+        o = Output(d)
+        c = ConstantMultiplication(0.5, d)
+        a = Addition(i, c)
+        d.input(0).connect(a)
+
+        sfg = SFG([i], [o])
+        sfg = sfg.insert_operation_before('t0', ConstantMultiplication(8))
+        sfg = sfg.insert_operation_after('t0', ConstantMultiplication(8))
+        sfg = sfg.insert_operation(ConstantMultiplication(8), 't0')
+        assert sfg.get_used_graph_ids() == {
+            'add0',
+            'cmul0',
+            'cmul1',
+            'cmul2',
+            'cmul3',
+            'in0',
+            'out0',
+            't0',
+        }
+
+    def test_large_operation_tree(self, large_operation_tree):
+        sfg = SFG(outputs=[Output(large_operation_tree)])
+        assert sfg.get_used_type_names() == ['add', 'c', 'out']
+
+
 class TestInsertDelays:
     def test_insert_delays_before_operation(self):
         in1 = Input()
-- 
GitLab