Skip to content
Snippets Groups Projects
test_sfg.py 9.58 KiB
Newer Older
  • Learn to ignore specific revisions
  • import pytest
    
    from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication
    
    
    
        def test_direct_input_to_output_sfg_construction(self):
            """An Input wired straight into an Output forms a minimal SFG.

            Expected: three components, two of which are operations, with
            exactly one input and one output.
            """
            source = Input("IN1")
            sink = Output(None, "OUT1")
            sink.input(0).connect(source, "S1")

            # source ---S1---> sink
            sfg = SFG(inputs=[source], outputs=[sink])

            assert sum(1 for _ in sfg.components) == 3
            assert sum(1 for _ in sfg.operations) == 2
            assert sfg.input_count == 1
            assert sfg.output_count == 1
    
        def test_same_signal_input_and_output_sfg_construction(self):
            """One signal may serve as both the input and the output signal of an SFG."""
            first_add = Addition(None, None, "ADD1")
            second_add = Addition(None, None, "ADD2")

            shared = second_add.input(0).connect(first_add, "S1")

            # Only two operations are expected below, so ADD1/ADD2 are not part of
            # the SFG; it presumably wraps S1 as: in1 ---s1---> out1
            sfg = SFG(input_signals=[shared], output_signals=[shared])

            assert sum(1 for _ in sfg.components) == 3
            assert sum(1 for _ in sfg.operations) == 2
            assert sfg.input_count == 1
            assert sfg.output_count == 1
    
        def test_outputs_construction(self, operation_tree):
            """An SFG can be built from output operations alone (no inputs given)."""
            sink = Output(operation_tree)
            sfg = SFG(outputs=[sink])

            assert sum(1 for _ in sfg.components) == 7
            assert sum(1 for _ in sfg.operations) == 4
            assert sfg.input_count == 0
            assert sfg.output_count == 1
    
        def test_signals_construction(self, operation_tree):
            """An SFG can be built from output signals alone (no inputs given)."""
            tail = Signal(source=operation_tree.output(0))
            sfg = SFG(output_signals=[tail])

            assert sum(1 for _ in sfg.components) == 7
            assert sum(1 for _ in sfg.operations) == 4
            assert sfg.input_count == 0
            assert sfg.output_count == 1
    
    
    class TestPrintSfg:
        """Tests for the textual representation of an SFG, obtained via ``str(sfg)``.

        Each line describes one operation as ``id``, ``name``, its input and
        output signal ids; constants additionally print their ``value``.
        """

        def test_one_addition(self):
            """Two inputs summed into a single output."""
            inp1 = Input("INP1")
            inp2 = Input("INP2")
            add1 = Addition(inp1, inp2, "ADD1")
            out1 = Output(add1, "OUT1")
            sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="sf1")

            assert str(sfg) == (
                "id: add1, name: ADD1, input: [s1, s2], output: [s3]\n"
                "id: in1, name: INP1, input: [], output: [s1]\n"
                "id: in2, name: INP2, input: [], output: [s2]\n"
                "id: out1, name: OUT1, input: [s3], output: []\n"
            )

        def test_add_mul(self):
            """A multiply-accumulate: (INP1 + INP2) * INP3."""
            inp1 = Input("INP1")
            inp2 = Input("INP2")
            inp3 = Input("INP3")
            add1 = Addition(inp1, inp2, "ADD1")
            mul1 = Multiplication(add1, inp3, "MUL1")
            out1 = Output(mul1, "OUT1")
            sfg = SFG(inputs=[inp1, inp2, inp3], outputs=[out1], name="mac_sfg")

            assert str(sfg) == (
                "id: add1, name: ADD1, input: [s1, s2], output: [s5]\n"
                "id: in1, name: INP1, input: [], output: [s1]\n"
                "id: in2, name: INP2, input: [], output: [s2]\n"
                "id: mul1, name: MUL1, input: [s5, s3], output: [s4]\n"
                "id: in3, name: INP3, input: [], output: [s3]\n"
                "id: out1, name: OUT1, input: [s4], output: []\n"
            )

        def test_constant(self):
            """A Constant prints its value in addition to the usual fields."""
            inp1 = Input("INP1")
            const1 = Constant(3, "CONST")
            add1 = Addition(const1, inp1, "ADD1")
            out1 = Output(add1, "OUT1")

            sfg = SFG(inputs=[inp1], outputs=[out1], name="sfg")

            assert str(sfg) == (
                "id: add1, name: ADD1, input: [s3, s1], output: [s2]\n"
                "id: c1, name: CONST, value: 3, input: [], output: [s3]\n"
                "id: in1, name: INP1, input: [], output: [s1]\n"
                "id: out1, name: OUT1, input: [s2], output: []\n"
            )

        def test_simple_filter(self, simple_filter):
            """The ``simple_filter`` fixture; its operations are unnamed (``name: ``)."""
            assert str(simple_filter) == (
                "id: add1, name: , input: [s1, s3], output: [s4]\n"
                "id: in1, name: , input: [], output: [s1]\n"
                "id: cmul1, name: , input: [s5], output: [s3]\n"
                "id: reg1, name: , input: [s4], output: [s5, s2]\n"
                "id: out1, name: , input: [s2], output: []\n"
            )
    
    
    class TestDeepCopy:
        """Tests for copying an SFG by calling it (``sfg(...)``)."""

        @staticmethod
        def _assert_components_copied(original, copied):
            """Assert every component of *original* has a distinct, same-named twin in *copied*.

            Components are matched by graph id via ``copied.find_by_id``.
            """
            for g_id, component in original._components_by_id.items():
                component_copy = copied.find_by_id(g_id)
                # Give a readable failure instead of an AttributeError on None.
                assert component_copy is not None, \
                    "component {!r} missing from copy".format(g_id)
                # A deep copy must not share component objects with the original.
                assert component is not component_copy
                assert component.name == component_copy.name

        def test_deep_copy_no_duplicates(self):
            """Calling an SFG with no arguments yields an unnamed, component-wise copy."""
            inp1 = Input("INP1")
            inp2 = Input("INP2")
            inp3 = Input("INP3")
            add1 = Addition(inp1, inp2, "ADD1")
            mul1 = Multiplication(add1, inp3, "MUL1")
            out1 = Output(mul1, "OUT1")

            mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg")
            mac_sfg_new = mac_sfg()

            # The copy does not inherit the original's name when none is given.
            assert mac_sfg.name == "mac_sfg"
            assert mac_sfg_new.name == ""

            self._assert_components_copied(mac_sfg, mac_sfg_new)

        def test_deep_copy(self):
            """A copy of a multi-path graph keeps ``id_number_offset`` and takes a new name."""
            inp1 = Input("INP1")
            inp2 = Input("INP2")
            inp3 = Input("INP3")
            add1 = Addition(None, None, "ADD1")
            add2 = Addition(None, None, "ADD2")
            mul1 = Multiplication(None, None, "MUL1")
            out1 = Output(None, "OUT1")

            add1.input(0).connect(inp1, "S1")
            add1.input(1).connect(inp2, "S2")
            add2.input(0).connect(add1, "S4")
            add2.input(1).connect(inp3, "S3")
            mul1.input(0).connect(add1, "S5")
            mul1.input(1).connect(add2, "S6")
            out1.input(0).connect(mul1, "S7")

            mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1],
                          id_number_offset=100, name="mac_sfg")
            mac_sfg_new = mac_sfg(name="mac_sfg2")

            assert mac_sfg.name == "mac_sfg"
            assert mac_sfg_new.name == "mac_sfg2"
            assert mac_sfg.id_number_offset == 100
            assert mac_sfg_new.id_number_offset == 100

            self._assert_components_copied(mac_sfg, mac_sfg_new)

        def test_deep_copy_with_new_sources(self):
            """Positional call arguments become the source operations of the copy's inputs."""
            inp1 = Input("INP1")
            inp2 = Input("INP2")
            inp3 = Input("INP3")
            add1 = Addition(inp1, inp2, "ADD1")
            mul1 = Multiplication(add1, inp3, "MUL1")
            out1 = Output(mul1, "OUT1")

            mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg")

            a = Addition(Constant(3), Constant(5))
            b = Constant(2)
            mac_sfg_new = mac_sfg(a, b)
            assert mac_sfg_new.input(0).signals[0].source.operation is a
            assert mac_sfg_new.input(1).signals[0].source.operation is b
    
    
    class TestEvaluateOutput:
        """Tests for ``SFG.evaluate_output`` on SFGs built from a single output.

        Each SFG is built from outputs only and evaluated at output index 0
        with an empty list of input values.
        """

        @staticmethod
        def _single_output_sfg(source):
            """Wrap *source* in an Output and build an SFG from it."""
            return SFG(outputs=[Output(source)])

        def test_evaluate_output(self, operation_tree):
            """The ``operation_tree`` fixture evaluates to 5."""
            sfg = self._single_output_sfg(operation_tree)
            result = sfg.evaluate_output(0, [])
            assert result == 5

        def test_evaluate_output_large(self, large_operation_tree):
            """The ``large_operation_tree`` fixture evaluates to 14."""
            sfg = self._single_output_sfg(large_operation_tree)
            result = sfg.evaluate_output(0, [])
            assert result == 14

        def test_evaluate_output_cycle(self, operation_graph_with_cycle):
            """Evaluating a graph that contains a cycle must raise."""
            sfg = self._single_output_sfg(operation_graph_with_cycle)
            with pytest.raises(Exception):
                sfg.evaluate_output(0, [])
    
    
    class TestComponents:
        """Tests for ``SFG.components``."""

        def test_advanced_components(self):
            """Every named operation and signal wired into the SFG is one of its components."""
            inp1 = Input("INP1")
            inp2 = Input("INP2")
            inp3 = Input("INP3")
            add1 = Addition(None, None, "ADD1")
            add2 = Addition(None, None, "ADD2")
            mul1 = Multiplication(None, None, "MUL1")
            out1 = Output(None, "OUT1")

            add1.input(0).connect(inp1, "S1")
            add1.input(1).connect(inp2, "S2")
            add2.input(0).connect(add1, "S4")
            add2.input(1).connect(inp3, "S3")
            mul1.input(0).connect(add1, "S5")
            mul1.input(1).connect(add2, "S6")
            out1.input(0).connect(mul1, "S7")

            # INP3 is not passed in ``inputs`` but is still expected among the components.
            mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg")

            expected_names = {
                "INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1",
                "S1", "S2", "S3", "S4", "S5", "S6", "S7",
            }
            assert {comp.name for comp in mac_sfg.components} == expected_names
    
    
    
    class TestReplaceComponents:
        """Tests for ``SFG.replace_component`` (target chosen by ``_id`` or ``_component``)."""

        def test_replace_addition_by_id(self, operation_tree):
            """Replacing "add1" by id removes that id and registers the new component's name."""
            sfg = SFG(outputs=[Output(operation_tree)])
            component_id = "add1"

            sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id)
            assert component_id not in sfg._components_by_id
            assert "Multi" in sfg._components_by_name

        def test_replace_addition_by_component(self, operation_tree):
            """Replacing via the component object has the same effect as replacing by id."""
            sfg = SFG(outputs=[Output(operation_tree)])
            component_id = "add1"
            component = sfg.find_by_id(component_id)

            sfg = sfg.replace_component(Multiplication(name="Multi"), _component=component)
            assert component_id not in sfg._components_by_id
            assert "Multi" in sfg._components_by_name

        def test_replace_addition_large_tree(self, large_operation_tree):
            """Replacement also works for "add3" inside the larger fixture graph."""
            sfg = SFG(outputs=[Output(large_operation_tree)])
            component_id = "add3"

            sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id)
            assert "Multi" in sfg._components_by_name
            assert component_id not in sfg._components_by_id

        def test_replace_no_input_component(self, operation_tree):
            """Replacing "c1" (presumably a Constant in the fixture) with a new Constant."""
            sfg = SFG(outputs=[Output(operation_tree)])
            component_id = "c1"
            _const = sfg.find_by_id(component_id)

            sfg = sfg.replace_component(Constant(1), _id=component_id)
            # The original object must no longer be what "c1" resolves to.
            assert _const is not sfg.find_by_id(component_id)

        def test_no_match_on_replace(self, large_operation_tree):
            """An id that matches no component is rejected."""
            sfg = SFG(outputs=[Output(large_operation_tree)])
            component_id = "addd1"  # deliberately not a valid id in the graph

            # replace_component signals failure by raising AssertionError.
            with pytest.raises(AssertionError):
                sfg.replace_component(Multiplication(name="Multi"), _id=component_id)

        def test_not_equal_input(self, large_operation_tree):
            """Replacing with an operation whose inputs presumably differ from "c1"'s is rejected."""
            sfg = SFG(outputs=[Output(large_operation_tree)])
            component_id = "c1"

            with pytest.raises(AssertionError):
                sfg.replace_component(Multiplication(name="Multi"), _id=component_id)