"""Tests for the :class:`b_asic.SFG` signal flow graph class.

Fixtures used here (``operation_tree``, ``large_operation_tree``,
``operation_graph_with_cycle``, ``simple_filter``,
``sfg_two_inputs_two_outputs`` and ``sfg_two_inputs_two_outputs_independent``)
are not defined in this module; presumably they live in a ``conftest.py`` --
confirm.
"""
import pytest

from b_asic import (
    SFG,
    Signal,
    Input,
    Output,
    Constant,
    ConstantMultiplication,
    Addition,
    Multiplication,
    Register,
    Butterfly,
    Subtraction,
)


def _port_keys(ports):
    """Return the set of operation keys for the output ports in ``ports``.

    Each key is ``port.operation.key(port.index, port.operation.name)``. The
    precedence-list assertions below expect this to be the bare operation name
    for single-output operations and ``"NAME.<index>"`` for multi-output ones.
    """
    return {port.operation.key(port.index, port.operation.name) for port in ports}


class TestInit:
    """Construction of an SFG from inputs/outputs or from signals."""

    def test_direct_input_to_output_sfg_construction(self):
        in1 = Input("IN1")
        out1 = Output(None, "OUT1")

        out1.input(0).connect(in1, "S1")

        sfg = SFG(inputs=[in1], outputs=[out1])  # in1 ---s1---> out1

        assert len(list(sfg.components)) == 3
        assert len(list(sfg.operations)) == 2
        assert sfg.input_count == 1
        assert sfg.output_count == 1

    def test_same_signal_input_and_output_sfg_construction(self):
        add1 = Addition(None, None, "ADD1")
        add2 = Addition(None, None, "ADD2")

        # add1 ---s1---> add2
        s1 = add2.input(0).connect(add1, "S1")

        # The very same signal is used as both the input and the output
        # signal of the SFG.
        sfg = SFG(input_signals=[s1], output_signals=[s1])

        assert len(list(sfg.components)) == 3
        assert len(list(sfg.operations)) == 2
        assert sfg.input_count == 1
        assert sfg.output_count == 1

    def test_outputs_construction(self, operation_tree):
        sfg = SFG(outputs=[Output(operation_tree)])

        assert len(list(sfg.components)) == 7
        assert len(list(sfg.operations)) == 4
        assert sfg.input_count == 0
        assert sfg.output_count == 1

    def test_signals_construction(self, operation_tree):
        sfg = SFG(output_signals=[Signal(source=operation_tree.output(0))])

        assert len(list(sfg.components)) == 7
        assert len(list(sfg.operations)) == 4
        assert sfg.input_count == 0
        assert sfg.output_count == 1


class TestPrintSfg:
    """String representation (``str(sfg)``) of an SFG."""

    def test_one_addition(self):
        inp1 = Input("INP1")
        inp2 = Input("INP2")
        add1 = Addition(inp1, inp2, "ADD1")
        out1 = Output(add1, "OUT1")
        sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="sf1")

        assert str(sfg) == (
            "id: add1, name: ADD1, input: [s1, s2], output: [s3]\n"
            "id: in1, name: INP1, input: [], output: [s1]\n"
            "id: in2, name: INP2, input: [], output: [s2]\n"
            "id: out1, name: OUT1, input: [s3], output: []\n"
        )

    def test_add_mul(self):
        inp1 = Input("INP1")
        inp2 = Input("INP2")
        inp3 = Input("INP3")
        add1 = Addition(inp1, inp2, "ADD1")
        mul1 = Multiplication(add1, inp3, "MUL1")
        out1 = Output(mul1, "OUT1")
        sfg = SFG(inputs=[inp1, inp2, inp3], outputs=[out1], name="mac_sfg")

        assert str(sfg) == (
            "id: add1, name: ADD1, input: [s1, s2], output: [s5]\n"
            "id: in1, name: INP1, input: [], output: [s1]\n"
            "id: in2, name: INP2, input: [], output: [s2]\n"
            "id: mul1, name: MUL1, input: [s5, s3], output: [s4]\n"
            "id: in3, name: INP3, input: [], output: [s3]\n"
            "id: out1, name: OUT1, input: [s4], output: []\n"
        )

    def test_constant(self):
        inp1 = Input("INP1")
        const1 = Constant(3, "CONST")
        add1 = Addition(const1, inp1, "ADD1")
        out1 = Output(add1, "OUT1")

        sfg = SFG(inputs=[inp1], outputs=[out1], name="sfg")

        assert str(sfg) == (
            "id: add1, name: ADD1, input: [s3, s1], output: [s2]\n"
            "id: c1, name: CONST, value: 3, input: [], output: [s3]\n"
            "id: in1, name: INP1, input: [], output: [s1]\n"
            "id: out1, name: OUT1, input: [s2], output: []\n"
        )

    def test_simple_filter(self, simple_filter):
        assert str(simple_filter) == (
            'id: add1, name: , input: [s1, s3], output: [s4]\n'
            'id: in1, name: , input: [], output: [s1]\n'
            'id: cmul1, name: , input: [s5], output: [s3]\n'
            'id: reg1, name: , input: [s4], output: [s5, s2]\n'
            'id: out1, name: , input: [s2], output: []\n'
        )


class TestDeepCopy:
    """Copying an SFG by calling it, optionally with new names/sources."""

    def test_deep_copy_no_duplicates(self):
        inp1 = Input("INP1")
        inp2 = Input("INP2")
        inp3 = Input("INP3")
        add1 = Addition(inp1, inp2, "ADD1")
        mul1 = Multiplication(add1, inp3, "MUL1")
        out1 = Output(mul1, "OUT1")

        mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg")
        mac_sfg_new = mac_sfg()

        assert mac_sfg.name == "mac_sfg"
        assert mac_sfg_new.name == ""

        # Every component of the original must have a same-named counterpart
        # with the same graph id in the copy.
        for g_id, component in mac_sfg._components_by_id.items():
            component_copy = mac_sfg_new.find_by_id(g_id)
            assert component.name == component_copy.name

    def test_deep_copy(self):
        inp1 = Input("INP1")
        inp2 = Input("INP2")
        inp3 = Input("INP3")
        add1 = Addition(None, None, "ADD1")
        add2 = Addition(None, None, "ADD2")
        mul1 = Multiplication(None, None, "MUL1")
        out1 = Output(None, "OUT1")

        add1.input(0).connect(inp1, "S1")
        add1.input(1).connect(inp2, "S2")
        add2.input(0).connect(add1, "S4")
        add2.input(1).connect(inp3, "S3")
        mul1.input(0).connect(add1, "S5")
        mul1.input(1).connect(add2, "S6")
        out1.input(0).connect(mul1, "S7")

        mac_sfg = SFG(
            inputs=[inp1, inp2],
            outputs=[out1],
            id_number_offset=100,
            name="mac_sfg",
        )
        mac_sfg_new = mac_sfg(name="mac_sfg2")

        assert mac_sfg.name == "mac_sfg"
        assert mac_sfg_new.name == "mac_sfg2"
        assert mac_sfg.id_number_offset == 100
        assert mac_sfg_new.id_number_offset == 100

        for g_id, component in mac_sfg._components_by_id.items():
            component_copy = mac_sfg_new.find_by_id(g_id)
            assert component.name == component_copy.name

    def test_deep_copy_with_new_sources(self):
        inp1 = Input("INP1")
        inp2 = Input("INP2")
        inp3 = Input("INP3")
        add1 = Addition(inp1, inp2, "ADD1")
        mul1 = Multiplication(add1, inp3, "MUL1")
        out1 = Output(mul1, "OUT1")

        mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg")

        a = Addition(Constant(3), Constant(5))
        b = Constant(2)
        mac_sfg_new = mac_sfg(a, b)

        # Positional arguments to the call become the sources of the copy's inputs.
        assert mac_sfg_new.input(0).signals[0].source.operation is a
        assert mac_sfg_new.input(1).signals[0].source.operation is b


class TestEvaluateOutput:
    """Evaluation of SFG outputs."""

    def test_evaluate_output(self, operation_tree):
        sfg = SFG(outputs=[Output(operation_tree)])
        assert sfg.evaluate_output(0, []) == 5

    def test_evaluate_output_large(self, large_operation_tree):
        sfg = SFG(outputs=[Output(large_operation_tree)])
        assert sfg.evaluate_output(0, []) == 14

    def test_evaluate_output_cycle(self, operation_graph_with_cycle):
        sfg = SFG(outputs=[Output(operation_graph_with_cycle)])
        # NOTE(review): the concrete exception type is not pinned here; a
        # narrower type would make this test stricter once it is known.
        with pytest.raises(Exception):
            sfg.evaluate_output(0, [])


class TestComponents:
    """Enumeration of all components (operations and signals) in an SFG."""

    def test_advanced_components(self):
        inp1 = Input("INP1")
        inp2 = Input("INP2")
        inp3 = Input("INP3")
        add1 = Addition(None, None, "ADD1")
        add2 = Addition(None, None, "ADD2")
        mul1 = Multiplication(None, None, "MUL1")
        out1 = Output(None, "OUT1")

        add1.input(0).connect(inp1, "S1")
        add1.input(1).connect(inp2, "S2")
        add2.input(0).connect(add1, "S4")
        add2.input(1).connect(inp3, "S3")
        mul1.input(0).connect(add1, "S5")
        mul1.input(1).connect(add2, "S6")
        out1.input(0).connect(mul1, "S7")

        mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg")

        assert {comp.name for comp in mac_sfg.components} == {
            "INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1",
            "S1", "S2", "S3", "S4", "S5", "S6", "S7",
        }


class TestReplaceComponents:
    """Replacing a component of an SFG by id or by component instance."""

    def test_replace_addition_by_id(self, operation_tree):
        sfg = SFG(outputs=[Output(operation_tree)])
        component_id = "add1"

        sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id)

        assert component_id not in sfg._components_by_id
        assert "Multi" in sfg._components_by_name

    def test_replace_addition_by_component(self, operation_tree):
        sfg = SFG(outputs=[Output(operation_tree)])
        component_id = "add1"
        component = sfg.find_by_id(component_id)

        sfg = sfg.replace_component(
            Multiplication(name="Multi"), _component=component)

        assert component_id not in sfg._components_by_id
        assert "Multi" in sfg._components_by_name

    def test_replace_addition_large_tree(self, large_operation_tree):
        sfg = SFG(outputs=[Output(large_operation_tree)])
        component_id = "add3"

        sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id)

        assert "Multi" in sfg._components_by_name
        assert component_id not in sfg._components_by_id

    def test_replace_no_input_component(self, operation_tree):
        sfg = SFG(outputs=[Output(operation_tree)])
        component_id = "c1"
        _const = sfg.find_by_id(component_id)

        sfg = sfg.replace_component(Constant(1), _id=component_id)

        assert _const is not sfg.find_by_id(component_id)

    def test_no_match_on_replace(self, large_operation_tree):
        sfg = SFG(outputs=[Output(large_operation_tree)])
        component_id = "addd1"  # deliberately misspelled: no such id exists

        with pytest.raises(AssertionError):
            sfg.replace_component(Multiplication(name="Multi"), _id=component_id)

    def test_not_equal_input(self, large_operation_tree):
        sfg = SFG(outputs=[Output(large_operation_tree)])
        # "c1" is replaced by a Constant elsewhere in this class; swapping it
        # for a Multiplication is expected to be rejected, presumably because
        # the input counts differ.
        component_id = "c1"

        with pytest.raises(AssertionError):
            sfg.replace_component(Multiplication(name="Multi"), _id=component_id)


class TestFindComponentsWithTypeName:
    """Looking up SFG components by their type name."""

    def test_mac_components(self):
        inp1 = Input("INP1")
        inp2 = Input("INP2")
        inp3 = Input("INP3")
        add1 = Addition(None, None, "ADD1")
        add2 = Addition(None, None, "ADD2")
        mul1 = Multiplication(None, None, "MUL1")
        out1 = Output(None, "OUT1")

        add1.input(0).connect(inp1, "S1")
        add1.input(1).connect(inp2, "S2")
        add2.input(0).connect(add1, "S4")
        add2.input(1).connect(inp3, "S3")
        mul1.input(0).connect(add1, "S5")
        mul1.input(1).connect(add2, "S6")
        out1.input(0).connect(mul1, "S7")

        mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg")

        def names_of_type(type_name):
            """Names of all components in ``mac_sfg`` with ``type_name``."""
            return {comp.name
                    for comp in mac_sfg.get_components_with_type_name(type_name)}

        assert names_of_type(inp1.type_name()) == {"INP1", "INP2", "INP3"}
        assert names_of_type(add1.type_name()) == {"ADD1", "ADD2"}
        assert names_of_type(mul1.type_name()) == {"MUL1"}
        assert names_of_type(out1.type_name()) == {"OUT1"}
        assert names_of_type(Signal.type_name()) == {
            "S1", "S2", "S3", "S4", "S5", "S6", "S7"}


class TestGetPrecedenceList:
    """Grouping of output ports into precedence levels."""

    def test_inputs_registers(self):
        in1 = Input("IN1")
        c0 = ConstantMultiplication(5, in1, "C0")
        add1 = Addition(c0, None, "ADD1")
        # Not sure what operation "Q" is supposed to be in the example
        Q1 = ConstantMultiplication(1, add1, "Q1")
        T1 = Register(Q1, 0, "T1")
        T2 = Register(T1, 0, "T2")
        b2 = ConstantMultiplication(2, T2, "B2")
        b1 = ConstantMultiplication(3, T1, "B1")
        add2 = Addition(b1, b2, "ADD2")
        add1.input(1).connect(add2)
        a1 = ConstantMultiplication(4, T1, "A1")
        a2 = ConstantMultiplication(6, T2, "A2")
        add3 = Addition(a1, a2, "ADD3")
        a0 = ConstantMultiplication(7, Q1, "A0")
        add4 = Addition(a0, add3, "ADD4")
        out1 = Output(add4, "OUT1")

        sfg = SFG(inputs=[in1], outputs=[out1], name="SFG")

        precedence_list = sfg.get_precedence_list()

        assert len(precedence_list) == 7
        assert _port_keys(precedence_list[0]) == {"IN1", "T1", "T2"}
        assert _port_keys(precedence_list[1]) == {"C0", "B1", "B2", "A1", "A2"}
        assert _port_keys(precedence_list[2]) == {"ADD2", "ADD3"}
        assert _port_keys(precedence_list[3]) == {"ADD1"}
        assert _port_keys(precedence_list[4]) == {"Q1"}
        assert _port_keys(precedence_list[5]) == {"A0"}
        assert _port_keys(precedence_list[6]) == {"ADD4"}

    def test_inputs_constants_registers_multiple_outputs(self):
        in1 = Input("IN1")
        c0 = ConstantMultiplication(5, in1, "C0")
        add1 = Addition(c0, None, "ADD1")
        # Not sure what operation "Q" is supposed to be in the example
        Q1 = ConstantMultiplication(1, add1, "Q1")
        T1 = Register(Q1, 0, "T1")
        const1 = Constant(10, "CONST1")  # Replace T2 register with a constant
        b2 = ConstantMultiplication(2, const1, "B2")
        b1 = ConstantMultiplication(3, T1, "B1")
        add2 = Addition(b1, b2, "ADD2")
        add1.input(1).connect(add2)
        a1 = ConstantMultiplication(4, T1, "A1")
        a2 = ConstantMultiplication(10, const1, "A2")
        add3 = Addition(a1, a2, "ADD3")
        a0 = ConstantMultiplication(7, Q1, "A0")
        # Replace ADD4 with a butterfly to test multiple output ports
        bfly1 = Butterfly(a0, add3, "BFLY1")
        out1 = Output(bfly1.output(0), "OUT1")
        # out2 is not passed to the SFG below; it only occupies BFLY1's
        # second output port.
        out2 = Output(bfly1.output(1), "OUT2")

        sfg = SFG(inputs=[in1], outputs=[out1], name="SFG")

        precedence_list = sfg.get_precedence_list()

        assert len(precedence_list) == 7
        assert _port_keys(precedence_list[0]) == {"IN1", "T1", "CONST1"}
        assert _port_keys(precedence_list[1]) == {"C0", "B1", "B2", "A1", "A2"}
        assert _port_keys(precedence_list[2]) == {"ADD2", "ADD3"}
        assert _port_keys(precedence_list[3]) == {"ADD1"}
        assert _port_keys(precedence_list[4]) == {"Q1"}
        assert _port_keys(precedence_list[5]) == {"A0"}
        assert _port_keys(precedence_list[6]) == {"BFLY1.0", "BFLY1.1"}

    def test_precedence_multiple_outputs_same_precedence(
            self, sfg_two_inputs_two_outputs):
        sfg_two_inputs_two_outputs.name = "NESTED_SFG"

        in1 = Input("IN1")
        sfg_two_inputs_two_outputs.input(0).connect(in1, "S1")
        in2 = Input("IN2")
        cmul1 = ConstantMultiplication(10, None, "CMUL1")
        cmul1.input(0).connect(in2, "S2")
        sfg_two_inputs_two_outputs.input(1).connect(cmul1, "S3")

        out1 = Output(sfg_two_inputs_two_outputs.output(0), "OUT1")
        out2 = Output(sfg_two_inputs_two_outputs.output(1), "OUT2")

        sfg = SFG(inputs=[in1, in2], outputs=[out1, out2])

        precedence_list = sfg.get_precedence_list()

        assert len(precedence_list) == 3
        assert _port_keys(precedence_list[0]) == {"IN1", "IN2"}
        assert _port_keys(precedence_list[1]) == {"CMUL1"}
        assert _port_keys(precedence_list[2]) == {"NESTED_SFG.0", "NESTED_SFG.1"}

    def test_precedence_sfg_multiple_outputs_different_precedences(
            self, sfg_two_inputs_two_outputs_independent):
        sfg_two_inputs_two_outputs_independent.name = "NESTED_SFG"

        in1 = Input("IN1")
        in2 = Input("IN2")
        sfg_two_inputs_two_outputs_independent.input(0).connect(in1, "S1")
        cmul1 = ConstantMultiplication(10, None, "CMUL1")
        cmul1.input(0).connect(in2, "S2")
        sfg_two_inputs_two_outputs_independent.input(1).connect(cmul1, "S3")

        out1 = Output(sfg_two_inputs_two_outputs_independent.output(0), "OUT1")
        out2 = Output(sfg_two_inputs_two_outputs_independent.output(1), "OUT2")

        sfg = SFG(inputs=[in1, in2], outputs=[out1, out2])

        precedence_list = sfg.get_precedence_list()

        assert len(precedence_list) == 3
        assert _port_keys(precedence_list[0]) == {"IN1", "IN2"}
        assert _port_keys(precedence_list[1]) == {"NESTED_SFG.0", "CMUL1"}
        assert _port_keys(precedence_list[2]) == {"NESTED_SFG.1"}


class TestDepends:
    """Which SFG inputs each output depends on."""

    def test_depends_sfg(self, sfg_two_inputs_two_outputs):
        assert set(sfg_two_inputs_two_outputs.inputs_required_for_output(0)) == {0, 1}
        assert set(sfg_two_inputs_two_outputs.inputs_required_for_output(1)) == {0, 1}

    def test_depends_sfg_independent(self, sfg_two_inputs_two_outputs_independent):
        assert set(
            sfg_two_inputs_two_outputs_independent.inputs_required_for_output(0)
        ) == {0}
        assert set(
            sfg_two_inputs_two_outputs_independent.inputs_required_for_output(1)
        ) == {1}


class TestConnectExternalSignalsToComponentsSoloComp:
    """Inlining an SFG that is the only component of an enclosing SFG."""

    def test_connect_external_signals_to_components_mac(self):
        """ Replace a MAC with inner components in an SFG """
        inp1 = Input("INP1")
        inp2 = Input("INP2")
        inp3 = Input("INP3")
        add1 = Addition(None, None, "ADD1")
        add2 = Addition(None, None, "ADD2")
        mul1 = Multiplication(None, None, "MUL1")
        out1 = Output(None, "OUT1")

        add1.input(0).connect(inp1, "S1")
        add1.input(1).connect(inp2, "S2")
        add2.input(0).connect(add1, "S3")
        add2.input(1).connect(inp3, "S4")
        mul1.input(0).connect(add1, "S5")
        mul1.input(1).connect(add2, "S6")
        out1.input(0).connect(mul1, "S7")

        mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1])

        inp4 = Input("INP4")
        inp5 = Input("INP5")
        out2 = Output(None, "OUT2")

        mac_sfg.input(0).connect(inp4, "S8")
        mac_sfg.input(1).connect(inp5, "S9")
        out2.input(0).connect(mac_sfg.outputs[0], "S10")

        test_sfg = SFG(inputs=[inp4, inp5], outputs=[out2])
        assert test_sfg.evaluate(1, 2) == 9
        mac_sfg.connect_external_signals_to_components()
        # Inlining must not change the computed result.
        assert test_sfg.evaluate(1, 2) == 9
        assert not test_sfg.connect_external_signals_to_components()

    def test_connect_external_signals_to_components_operation_tree(self, operation_tree):
        """ Replaces an SFG with only a operation_tree component with its inner components """
        sfg1 = SFG(outputs=[Output(operation_tree)])
        out1 = Output(None, "OUT1")
        out1.input(0).connect(sfg1.outputs[0], "S1")
        test_sfg = SFG(outputs=[out1])
        assert test_sfg.evaluate_output(0, []) == 5
        sfg1.connect_external_signals_to_components()
        assert test_sfg.evaluate_output(0, []) == 5
        assert not test_sfg.connect_external_signals_to_components()

    def test_connect_external_signals_to_components_large_operation_tree(
            self, large_operation_tree):
        """ Replaces an SFG with only a large_operation_tree component with its inner components """
        sfg1 = SFG(outputs=[Output(large_operation_tree)])
        out1 = Output(None, "OUT1")
        out1.input(0).connect(sfg1.outputs[0], "S1")
        test_sfg = SFG(outputs=[out1])
        assert test_sfg.evaluate_output(0, []) == 14
        sfg1.connect_external_signals_to_components()
        assert test_sfg.evaluate_output(0, []) == 14
        assert not test_sfg.connect_external_signals_to_components()


class TestConnectExternalSignalsToComponentsMultipleComp:
    """Inlining an SFG that sits among other components of an enclosing SFG."""

    def test_connect_external_signals_to_components_operation_tree(self, operation_tree):
        """ Replaces a operation_tree in an SFG with other components """
        sfg1 = SFG(outputs=[Output(operation_tree)])

        inp1 = Input("INP1")
        inp2 = Input("INP2")
        out1 = Output(None, "OUT1")

        add1 = Addition(None, None, "ADD1")
        add2 = Addition(None, None, "ADD2")

        add1.input(0).connect(inp1, "S1")
        add1.input(1).connect(inp2, "S2")
        add2.input(0).connect(add1, "S3")
        add2.input(1).connect(sfg1.outputs[0], "S4")
        out1.input(0).connect(add2, "S5")

        test_sfg = SFG(inputs=[inp1, inp2], outputs=[out1])
        assert test_sfg.evaluate(1, 2) == 8
        sfg1.connect_external_signals_to_components()
        assert test_sfg.evaluate(1, 2) == 8
        assert not test_sfg.connect_external_signals_to_components()

    def test_connect_external_signals_to_components_large_operation_tree(
            self, large_operation_tree):
        """ Replaces a large_operation_tree in an SFG with other components """
        sfg1 = SFG(outputs=[Output(large_operation_tree)])

        inp1 = Input("INP1")
        inp2 = Input("INP2")
        out1 = Output(None, "OUT1")
        add1 = Addition(None, None, "ADD1")
        add2 = Addition(None, None, "ADD2")

        add1.input(0).connect(inp1, "S1")
        add1.input(1).connect(inp2, "S2")
        add2.input(0).connect(add1, "S3")
        add2.input(1).connect(sfg1.outputs[0], "S4")
        out1.input(0).connect(add2, "S5")

        test_sfg = SFG(inputs=[inp1, inp2], outputs=[out1])
        assert test_sfg.evaluate(1, 2) == 17
        sfg1.connect_external_signals_to_components()
        assert test_sfg.evaluate(1, 2) == 17
        assert not test_sfg.connect_external_signals_to_components()

    def create_sfg(self, op_tree):
        """ Create a simple SFG with either operation_tree or large_operation_tree """
        sfg1 = SFG(outputs=[Output(op_tree)])

        inp1 = Input("INP1")
        inp2 = Input("INP2")
        out1 = Output(None, "OUT1")
        add1 = Addition(None, None, "ADD1")
        add2 = Addition(None, None, "ADD2")

        add1.input(0).connect(inp1, "S1")
        add1.input(1).connect(inp2, "S2")
        add2.input(0).connect(add1, "S3")
        add2.input(1).connect(sfg1.outputs[0], "S4")
        out1.input(0).connect(add2, "S5")

        return SFG(inputs=[inp1, inp2], outputs=[out1])

    def test_connect_external_signals_to_components_many_op(self, large_operation_tree):
        """ Replaces an sfg component in a larger SFG with several component operations """
        inp1 = Input("INP1")
        inp2 = Input("INP2")
        inp3 = Input("INP3")
        inp4 = Input("INP4")
        out1 = Output(None, "OUT1")
        add1 = Addition(None, None, "ADD1")
        sub1 = Subtraction(None, None, "SUB1")

        add1.input(0).connect(inp1, "S1")
        add1.input(1).connect(inp2, "S2")

        sfg1 = self.create_sfg(large_operation_tree)

        sfg1.input(0).connect(add1, "S3")
        sfg1.input(1).connect(inp3, "S4")
        sub1.input(0).connect(sfg1.outputs[0], "S5")
        sub1.input(1).connect(inp4, "S6")
        out1.input(0).connect(sub1, "S7")

        test_sfg = SFG(inputs=[inp1, inp2, inp3, inp4], outputs=[out1])

        assert test_sfg.evaluate(1, 2, 3, 4) == 16
        sfg1.connect_external_signals_to_components()
        assert test_sfg.evaluate(1, 2, 3, 4) == 16
        assert not test_sfg.connect_external_signals_to_components()