diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 4dfb5ef34340a60b55234c2aa6be6de306b5a713..6483cfc1476047bcbe897b871cc179990b894c4d 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -7,7 +7,7 @@ from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, Mutabl from numbers import Number from collections import defaultdict, deque -from b_asic.port import SignalSourceProvider, OutputPort +from b_asic.port import SignalSourceProvider, OutputPort, InputPort from b_asic.operation import Operation, AbstractOperation, MutableOutputMap, MutableRegisterMap from b_asic.signal import Signal from b_asic.graph_component import GraphID, GraphIDNumber, GraphComponent, Name, TypeName @@ -480,23 +480,17 @@ class SFG(AbstractOperation): # Add connected operation to the queue of operations to visit. op_stack.append(original_connected_op) - def replace_component(self, component: Operation, _component: Operation = None, _id: GraphID = None): + def replace_component(self, component: Operation, _id: GraphID): """Find and replace all components matching either on GraphID, Type or both. Then return a new deepcopy of the sfg with the replaced component. Arguments: component: The new component(s), e.g Multiplication - - Keyword arguments: - _component: The specific component to replace. _id: The GraphID to match the component to replace. """ - assert _component is not None or _id is not None, \ - "Define either operation to replace or GraphID of operation" - - if _id is not None: - _component = self.find_by_id(_id) + _sfg_copy = self() + _component = _sfg_copy.find_by_id(_id) assert _component is not None and isinstance(_component, Operation), \ "No operation matching the criteria found" @@ -516,7 +510,38 @@ class SFG(AbstractOperation): _signal.set_source(component.output(index_out)) # The old SFG will be deleted by Python GC - return self() + return _sfg_copy() + + def insert_operation(self, component: Operation, output_comp_id: GraphID): + """Insert an operation in the SFG after a given source operation. + The source operation output count must match the input count of the operation as well as the output + Then return a new deepcopy of the sfg with the inserted component. + + Arguments: + component: The new component, e.g Multiplication. + output_comp_id: The source operation GraphID to connect from. + """ + + # Preserve the original SFG by creating a copy. + sfg_copy = self() + output_comp = sfg_copy.find_by_id(output_comp_id) + if output_comp is None: + return None + + assert not isinstance(output_comp, Output), \ + "Source operation can not be an output operation." + assert len(output_comp.output_signals) == component.input_count, \ + "Source operation output count does not match input count for component." + assert len(output_comp.output_signals) == component.output_count, \ + "Destination operation input count does not match output for component." + + for index, signal_in in enumerate(output_comp.output_signals): + destination = signal_in.destination + signal_in.set_destination(component.input(index)) + destination.connect(component.output(index)) + + # Recreate the newly coupled SFG so that all attributes are correct. + return sfg_copy() def _evaluate_source(self, src: OutputPort, results: MutableOutputMap, registers: MutableRegisterMap, prefix: str) -> Number: src_prefix = prefix diff --git a/test/fixtures/operation_tree.py b/test/fixtures/operation_tree.py index fc8008fa4098ca488e23766f5ff7d05711300685..695979c65ab56eda3baf992b3b99963ee1fe7c9a 100644 --- a/test/fixtures/operation_tree.py +++ b/test/fixtures/operation_tree.py @@ -1,6 +1,6 @@ import pytest -from b_asic import Addition, Constant, Signal +from b_asic import Addition, Constant, Signal, Butterfly @pytest.fixture @@ -41,6 +41,41 @@ def large_operation_tree(): """ return Addition(Addition(Constant(2), Constant(3)), Addition(Constant(4), Constant(5))) +@pytest.fixture +def large_operation_tree_names(): + """Valid addition operation connected with a large operation tree with 2 other additions and 4 constants. + With names. + 2---+ + | + v + add---+ + ^ | + | | + 3---+ v + add = (2 + 3) + (4 + 5) = 14 + 4---+ ^ + | | + v | + add---+ + ^ + | + 5---+ + """ + return Addition(Addition(Constant(2, name="constant2"), Constant(3, name="constant3")), Addition(Constant(4, name="constant4"), Constant(5, name="constant5"))) + +@pytest.fixture +def butterfly_operation_tree(): + """Valid butterfly operations connected to eachother with 3 butterfly operations and 2 constants as inputs and 2 outputs. + 2 ---+ +--- (2 + 4) ---+ +--- (6 + (-2)) ---+ +--- (4 + 8) ---> out1 = 12 + | | | | | | + v ^ v ^ v ^ + butterfly butterfly butterfly + ^ v ^ v ^ v + | | | | | | + 4 ---+ +--- (2 - 4) ---+ +--- (6 - (-2)) ---+ +--- (4 - 8) ---> out2 = -4 + """ + return Butterfly(*(Butterfly(*(Butterfly(Constant(2), Constant(4), name="bfly3").outputs), name="bfly2").outputs), name="bfly1") + @pytest.fixture def operation_graph_with_cycle(): """Invalid addition operation connected with an operation graph containing a cycle. diff --git a/test/test_sfg.py b/test/test_sfg.py index cf309c263d65848ab31b4b00e78ebf99132b8ad7..5f86739517b0d4c7bc9b242de24c3777222b51d2 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1,7 +1,7 @@ import pytest from b_asic import SFG, Signal, Input, Output, Constant, ConstantMultiplication, Addition, Multiplication, Register, \ - Butterfly, Subtraction + Butterfly, Subtraction, SquareRoot class TestInit: @@ -217,16 +217,6 @@ class TestReplaceComponents: assert component_id not in sfg._components_by_id.keys() assert "Multi" in sfg._components_by_name.keys() - def test_replace_addition_by_component(self, operation_tree): - sfg = SFG(outputs=[Output(operation_tree)]) - component_id = "add1" - component = sfg.find_by_id(component_id) - - sfg = sfg.replace_component(Multiplication( - name="Multi"), _component=component) - assert component_id not in sfg._components_by_id.keys() - assert "Multi" in sfg._components_by_name.keys() - def test_replace_addition_large_tree(self, large_operation_tree): sfg = SFG(outputs=[Output(large_operation_tree)]) component_id = "add3" @@ -269,6 +259,67 @@ class TestReplaceComponents: assert False +class TestInsertComponent: + + def test_insert_component_in_sfg(self, large_operation_tree_names): + sfg = SFG(outputs=[Output(large_operation_tree_names)]) + sqrt = SquareRoot() + + _sfg = sfg.insert_operation(sqrt, sfg.find_by_name("constant4")[0].graph_id) + assert _sfg.evaluate() != sfg.evaluate() + + assert any([isinstance(comp, SquareRoot) for comp in _sfg.operations]) + assert not any([isinstance(comp, SquareRoot) for comp in sfg.operations]) + + assert not isinstance(sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot) + assert isinstance(_sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot) + + assert sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is sfg.find_by_id("add3") + assert _sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is not _sfg.find_by_id("add3") + assert _sfg.find_by_id("sqrt1").output(0).signals[0].destination.operation is _sfg.find_by_id("add3") + + def test_insert_invalid_component_in_sfg(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + + # Should raise an exception for not matching input count to output count. + add4 = Addition() + with pytest.raises(Exception): + sfg.insert_operation(add4, "c1") + + def test_insert_at_output(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + + # Should raise an exception for trying to insert an operation after an output. + sqrt = SquareRoot() + with pytest.raises(Exception): + _sfg = sfg.insert_operation(sqrt, "out1") + + def test_insert_multiple_output_ports(self, butterfly_operation_tree): + sfg = SFG(outputs=list(map(Output, butterfly_operation_tree.outputs))) + _sfg = sfg.insert_operation(Butterfly(name="n_bfly"), "bfly3") + + assert sfg.evaluate() != _sfg.evaluate() + + assert len(sfg.find_by_name("n_bfly")) == 0 + assert len(_sfg.find_by_name("n_bfly")) == 1 + + # Correctly connected old output -> new input + assert _sfg.find_by_name("bfly3")[0].output(0).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] + assert _sfg.find_by_name("bfly3")[0].output(1).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] + + # Correctly connected new input -> old output + assert _sfg.find_by_name("n_bfly")[0].input(0).signals[0].source.operation is _sfg.find_by_name("bfly3")[0] + assert _sfg.find_by_name("n_bfly")[0].input(1).signals[0].source.operation is _sfg.find_by_name("bfly3")[0] + + # Correctly connected new output -> next input + assert _sfg.find_by_name("n_bfly")[0].output(0).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] + assert _sfg.find_by_name("n_bfly")[0].output(1).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] + + # Correctly connected next input -> new output + assert _sfg.find_by_name("bfly2")[0].input(0).signals[0].source.operation is _sfg.find_by_name("n_bfly")[0] + assert _sfg.find_by_name("bfly2")[0].input(1).signals[0].source.operation is _sfg.find_by_name("n_bfly")[0] + + class TestFindComponentsWithTypeName: def test_mac_components(self): inp1 = Input("INP1")