diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index e8e7af01ab93fdba948d9ff7ec19078b3b71dee6..71a09eb0592b5dafe42a5a3e01d3b6ce1c1bd6c8 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -7,7 +7,7 @@ from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, Set from numbers import Number from collections import defaultdict, deque -from b_asic.port import SignalSourceProvider, OutputPort +from b_asic.port import SignalSourceProvider, OutputPort, InputPort from b_asic.operation import Operation, AbstractOperation, ResultKey, RegisterMap, MutableResultMap, MutableRegisterMap from b_asic.signal import Signal from b_asic.graph_component import GraphID, GraphIDNumber, GraphComponent, Name, TypeName @@ -367,6 +367,9 @@ class SFG(AbstractOperation): # Add connected operation to the queue of operations to visit. op_stack.append(original_connected_op) + def couple_operation(self, src: OutputPort, dest: InputPort): + pass + def _evaluate_source(self, src: OutputPort, results: MutableResultMap, registers: MutableRegisterMap, prefix: str) -> Number: src_prefix = prefix if src_prefix: diff --git a/test/fixtures/operation_tree.py b/test/fixtures/operation_tree.py index fc8008fa4098ca488e23766f5ff7d05711300685..9ac5dec3e60aeb28b854d7c69ac114934c3ea8b0 100644 --- a/test/fixtures/operation_tree.py +++ b/test/fixtures/operation_tree.py @@ -1,6 +1,6 @@ import pytest -from b_asic import Addition, Constant, Signal +from b_asic import Addition, Constant, Signal, SquareRoot @pytest.fixture @@ -58,3 +58,25 @@ def operation_graph_with_cycle(): add1 = Addition(None, Constant(7)) add1.input(0).connect(add1) return Addition(add1, Constant(6)) + +@pytest.fixture +def large_operation_tree_one_input(): + """Addition operation connected with a large operation tree with 2 other additions and 3 constants + and one placeholder empty square root operation. + ----+ + | + v + add---+ + ^ | + | | + 3---+ v + add = ? + 4---+ ^ + | | + v | + add---+ + ^ + | + 5---+ + """ + return Addition(Addition(SquareRoot(), Constant(3)), Addition(Constant(4), Constant(5))) diff --git a/test/test_sfg.py b/test/test_sfg.py new file mode 100644 index 0000000000000000000000000000000000000000..d99767479c87af10da411dc0bc22f3139a20989e --- /dev/null +++ b/test/test_sfg.py @@ -0,0 +1,116 @@ +from b_asic import SFG, Operation +from b_asic.signal import Signal +from b_asic.core_operations import Addition, Constant, Multiplication +from b_asic.special_operations import Input, Output + + +class TestConstructor: + def test_direct_input_to_output_sfg_construction(self): + inp = Input("INP1") + out = Output(None, "OUT1") + out.input(0).connect(inp, "S1") + + sfg = SFG(inputs=[inp], outputs=[out]) + + assert len(list(sfg.components)) == 3 + assert sfg.input_count == 1 + assert sfg.output_count == 1 + + def test_same_signal_input_and_output_sfg_construction(self): + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + + sig1 = add2.input(0).connect(add1, "S1") + + sfg = SFG(input_signals=[sig1], output_signals=[sig1]) + + assert len(list(sfg.components)) == 3 + assert sfg.input_count == 1 + assert sfg.output_count == 1 + + def test_outputs_construction(self, operation_tree): + outp = Output(operation_tree) + sfg = SFG(outputs=[outp]) + + assert len(list(sfg.components)) == 7 + assert sfg.input_count == 0 + assert sfg.output_count == 1 + + def test_signals_construction(self, operation_tree): + outs = Signal(source=operation_tree.output(0)) + sfg = SFG(output_signals=[outs]) + + assert len(list(sfg.components)) == 7 + assert sfg.input_count == 0 + assert sfg.output_count == 1 + + +class TestDeepCopy: + def test_deep_copy_no_duplicates(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(inp1, inp2, "ADD1") + mul1 = Multiplication(add1, inp3, "MUL1") + out1 = Output(mul1, "OUT1") + + mac_sfg = SFG(inputs=[inp1, inp2], + outputs=[out1], name="mac_sfg") + + mac_sfg_deep_copy = mac_sfg.deep_copy() + + for g_id, component in mac_sfg._components_by_id.items(): + component_copy = mac_sfg_deep_copy.find_by_id(g_id) + assert component.name == component_copy.name + + def test_deep_copy(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + mul1 = Multiplication(None, None, "MUL1") + out1 = Output(None, "OUT1") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + add2.input(0).connect(add1, "S4") + add2.input(1).connect(inp3, "S3") + mul1.input(0).connect(add1, "S5") + mul1.input(1).connect(add2, "S6") + out1.input(0).connect(mul1, "S7") + + mac_sfg = SFG(inputs=[inp1, inp2], + outputs=[out1], name="mac_sfg") + + mac_sfg_deep_copy = mac_sfg.deep_copy() + + for g_id, component in mac_sfg._components_by_id.items(): + component_copy = mac_sfg_deep_copy.find_by_id(g_id) + assert component.name == component_copy.name + + +class TestComponents: + + def test_advanced_components(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + mul1 = Multiplication(None, None, "MUL1") + out1 = Output(None, "OUT1") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + add2.input(0).connect(add1, "S4") + add2.input(1).connect(inp3, "S3") + mul1.input(0).connect(add1, "S5") + mul1.input(1).connect(add2, "S6") + out1.input(0).connect(mul1, "S7") + + mac_sfg = SFG(inputs=[inp1, inp2], + outputs=[out1], name="mac_sfg") + + assert set([comp.name for comp in mac_sfg.components]) == { + "INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"} diff --git a/test/test_signal_flow_graph.py b/test/test_signal_flow_graph.py index 51267cc44ece60c05e622c02c89b8ec1a5d5b17d..8bd579b9f2a5471fe9bd2abaadb440363cb3dc87 100644 --- a/test/test_signal_flow_graph.py +++ b/test/test_signal_flow_graph.py @@ -1,6 +1,6 @@ import pytest -from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication +from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication, SquareRoot class TestConstructor: @@ -147,3 +147,24 @@ class TestComponents: mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") assert set([comp.name for comp in mac_sfg.components]) == {"INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"} + + +class TestCoupling: + + def test_couple_sfg_with_one_operation(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + _sqrt = SquareRoot(name="Sqrt") + sfg = sfg.couple_operation(sfg.output(0), _sqrt.input(0)) + assert "Sqrt" in sfg._components_by_name.keys() + + def test_couple_sfg_with_operation_tree(self, large_operation_tree, large_operation_tree_one_input): + sfg = SFG(outputs=[Output(large_operation_tree)]) + _sqrt = [comp for comp in large_operation_tree_one_input.traverse() if isinstance(comp, SquareRoot)][0] + sfg = sfg.couple_operation(sfg.output(0), _sqrt.input(0)) + assert "Sqrt" in sfg._components_by_name.keys() + + def test_couple_sfg_with_uneven_ports(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + _addition = Addition() + sfg = sfg.couple_operation(sfg.output(0), _addition.input(0)) + assert "Sqrt" in sfg._components_by_name.keys()