Skip to content
Snippets Groups Projects
test_operation.py 10.4 KiB
Newer Older
  • Learn to ignore specific revisions
  • """
    B-ASIC test suite for the AbstractOperation class.
    """
    
    import re
    
    import pytest
    
    Oscar Gustafsson's avatar
    Oscar Gustafsson committed
    from b_asic import (
        MAD,
        Addition,
        Butterfly,
        Constant,
        ConstantMultiplication,
        Division,
        Multiplication,
    
        Reciprocal,
    
    Oscar Gustafsson's avatar
    Oscar Gustafsson committed
        SquareRoot,
        Subtraction,
    )
    
    
    
    class TestOperationOverloading:
        """Arithmetic operators on operations build the expected operation nodes."""

        @staticmethod
        def _constant_value(port):
            """Return ``value`` of the operation driving the first signal of *port*."""
            return port.signals[0].source.operation.value

        def test_addition_overload(self):
            """``+`` works with operation/operation, operation/number and number/operation."""
            lhs = Addition(None, None, "add1")
            rhs = Addition(None, None, "add2")

            op_op = lhs + rhs
            assert isinstance(op_op, Addition)
            assert op_op.input(0).signals == lhs.output(0).signals
            assert op_op.input(1).signals == rhs.output(0).signals

            op_num = op_op + 5
            assert isinstance(op_num, Addition)
            assert op_num.input(0).signals == op_op.output(0).signals
            assert self._constant_value(op_num.input(1)) == 5

            num_op = 5 + op_num
            assert isinstance(num_op, Addition)
            assert self._constant_value(num_op.input(0)) == 5
            assert num_op.input(1).signals == op_num.output(0).signals

        def test_subtraction_overload(self):
            """``-`` works with operation/operation, operation/number and number/operation."""
            lhs = Addition(None, None, "add1")
            rhs = Addition(None, None, "add2")

            op_op = lhs - rhs
            assert isinstance(op_op, Subtraction)
            assert op_op.input(0).signals == lhs.output(0).signals
            assert op_op.input(1).signals == rhs.output(0).signals

            op_num = op_op - 5
            assert isinstance(op_num, Subtraction)
            assert op_num.input(0).signals == op_op.output(0).signals
            assert self._constant_value(op_num.input(1)) == 5

            num_op = 5 - op_num
            assert isinstance(num_op, Subtraction)
            assert self._constant_value(num_op.input(0)) == 5
            assert num_op.input(1).signals == op_num.output(0).signals

        def test_multiplication_overload(self):
            """``*`` by a number yields a ConstantMultiplication on either side."""
            lhs = Addition(None, None, "add1")
            rhs = Addition(None, None, "add2")

            op_op = lhs * rhs
            assert isinstance(op_op, Multiplication)
            assert op_op.input(0).signals == lhs.output(0).signals
            assert op_op.input(1).signals == rhs.output(0).signals

            op_num = op_op * 5
            assert isinstance(op_num, ConstantMultiplication)
            assert op_num.input(0).signals == op_op.output(0).signals
            assert op_num.value == 5

            num_op = 5 * op_num
            assert isinstance(num_op, ConstantMultiplication)
            assert num_op.input(0).signals == op_num.output(0).signals
            assert num_op.value == 5

        def test_division_overload(self):
            """``/`` builds a Division; ``1 / op`` builds a Reciprocal instead."""
            lhs = Addition(None, None, "add1")
            rhs = Addition(None, None, "add2")

            op_op = lhs / rhs
            assert isinstance(op_op, Division)
            assert op_op.input(0).signals == lhs.output(0).signals
            assert op_op.input(1).signals == rhs.output(0).signals

            op_num = op_op / 5
            assert isinstance(op_num, Division)
            assert op_num.input(0).signals == op_op.output(0).signals
            assert self._constant_value(op_num.input(1)) == 5

            num_op = 5 / op_num
            assert isinstance(num_op, Division)
            assert self._constant_value(num_op.input(0)) == 5
            assert num_op.input(1).signals == op_num.output(0).signals

            # A numerator of exactly 1 is special-cased to a Reciprocal node.
            one_over = 1 / num_op
            assert isinstance(one_over, Reciprocal)
            assert one_over.input(0).signals == num_op.output(0).signals
    
    
    class TestTraverse:
        """Tests for ``traverse()`` over trees and graphs of operations."""

        def test_traverse_single_tree(self, operation):
            """Traverse a tree consisting of one operation."""
            # NOTE(review): the ``operation`` fixture is requested but unused.
            constant = Constant(None)
            assert list(constant.traverse()) == [constant]

        def test_traverse_tree(self, operation_tree):
            """Traverse a basic addition tree with two constants."""
            assert len(list(operation_tree.traverse())) == 5

        def test_traverse_large_tree(self, large_operation_tree):
            """Traverse a larger tree."""
            assert len(list(large_operation_tree.traverse())) == 13

        def test_traverse_type(self, large_operation_tree):
            """Traversal of the large tree yields 3 Additions and 4 Constants."""
            result = list(large_operation_tree.traverse())

            assert sum(isinstance(op, Addition) for op in result) == 3
            assert sum(isinstance(op, Constant) for op in result) == 4

        def test_traverse_loop(self, operation_graph_with_cycle):
            """Traversal of a graph containing a cycle terminates with 8 operations."""
            assert len(list(operation_graph_with_cycle.traverse())) == 8
    
    
    class TestToSfg:
        """Conversion of operations to equivalent signal flow graphs."""

        def test_convert_mad_to_sfg(self):
            """A MAD and its SFG evaluate equally; the SFG holds 6 operations."""
            mad = MAD()
            graph = mad.to_sfg()

            assert mad.evaluate(1, 1, 1) == graph.evaluate(1, 1, 1)
            assert len(graph.operations) == 6

        def test_butterfly_to_sfg(self):
            """Both butterfly outputs match the SFG; the SFG holds 8 operations."""
            bfly = Butterfly()
            graph = bfly.to_sfg()

            for idx in (0, 1):
                assert bfly.evaluate(1, 1)[idx] == graph.evaluate(1, 1)[idx]
            assert len(graph.operations) == 8

        def test_add_to_sfg(self):
            """The SFG of an Addition holds 4 operations."""
            graph = Addition().to_sfg()

            assert len(graph.operations) == 4

        def test_sqrt_to_sfg(self):
            """The SFG of a SquareRoot holds 3 operations."""
            graph = SquareRoot().to_sfg()

            assert len(graph.operations) == 3
    
    
    class TestLatency:
        """Tests for operation latency and per-port latency offsets."""

        def test_latency_constructor(self):
            """``latency`` and ``latency_offsets`` constructor arguments are honoured.

            With ``latency=5`` every input gets offset 0 and every output offset
            5; explicitly passed offsets are stored verbatim.
            """
            bfly = Butterfly(latency=5)

            assert bfly.latency == 5
            assert bfly.latency_offsets == {
                "in0": 0,
                "in1": 0,
                "out0": 5,
                "out1": 5,
            }

            bfly = Butterfly(
                latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10}
            )

            assert bfly.latency_offsets == {
                "in0": 2,
                "in1": 3,
                "out0": 5,
                "out1": 10,
            }

        def test_latency_and_latency_offsets_constructor(self):
            """Explicit offsets override those derived from ``latency``."""
            bfly = Butterfly(latency=5, latency_offsets={"in1": 2, "out0": 9})

            assert bfly.latency_offsets == {
                "in0": 0,
                "in1": 2,
                "out0": 9,
                "out1": 5,
            }

        def test_set_latency(self):
            """``set_latency`` updates the latency and all output offsets."""
            bfly = Butterfly()

            bfly.set_latency(9)

            assert bfly.latency == 9
            assert bfly.latency_offsets == {
                "in0": 0,
                "in1": 0,
                "out0": 9,
                "out1": 9,
            }

    
    class TestExecutionTime:
        """Tests for execution time and for copying timing properties."""

        def test_execution_time_constructor(self):
            # TODO(review): placeholder without assertions; presumably meant to
            # cover an execution-time constructor argument. Confirm and fill in.
            pass

        def test_set_execution_time(self):
            """Assigning ``execution_time`` stores the value."""
            bfly = Butterfly()
            bfly.execution_time = 3

            assert bfly.execution_time == 3

        def test_set_execution_time_negative(self):
            """A negative execution time is rejected with ValueError."""
            bfly = Butterfly()
            with pytest.raises(
                ValueError, match="Execution time cannot be negative"
            ):
                bfly.execution_time = -1

        def test_copy_butterfly_latency_offsets(self):
            """``copy_component`` preserves per-port latency offsets."""
            bfly = Butterfly(
                latency_offsets={"in0": 4, "in1": 2, "out0": 10, "out1": 9}
            )

            # The assertion is about the copy, so it must be created first.
            bfly_copy = bfly.copy_component()

            assert bfly_copy.latency_offsets == {
                "in0": 4,
                "in1": 2,
                "out0": 10,
                "out1": 9,
            }

        def test_copy_execution_time(self):
            """``copy_component`` preserves the execution time."""
            add = Addition()
            add.execution_time = 2

            add_copy = add.copy_component()

            assert add_copy.execution_time == 2
    
    
    
    Oscar Gustafsson's avatar
    Oscar Gustafsson committed
    class TestPlotCoordinates:
        """Tests for ``get_plot_coordinates``.

        The method returns two closed point lists (first point repeated at the
        end). The first outlines the latency offsets and the second the
        execution time.
        """

        def test_simple_case(self):
            """Single-input operation: both outlines are plain rectangles."""
            cmult = ConstantMultiplication(0.5)
            cmult.execution_time = 1
            cmult.set_latency(3)

            lat, exe = cmult.get_plot_coordinates()
            assert lat == [[0, 0], [0, 1], [3, 1], [3, 0], [0, 0]]
            assert exe == [[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]

        def test_complicated_case(self):
            """Two-port butterfly with distinct offsets: the latency outline steps at y=0.5."""
            bfly = Butterfly(
                latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10}
            )

            bfly.execution_time = 7

            lat, exe = bfly.get_plot_coordinates()

            assert lat == [
                [2, 0],
                [2, 0.5],
                [3, 0.5],
                [3, 1],
                [10, 1],
                [10, 0.5],
                [5, 0.5],
                [5, 0],
                [2, 0],
            ]

            assert exe == [[0, 0], [0, 1], [7, 1], [7, 0], [0, 0]]
    
    
    
    Oscar Gustafsson's avatar
    Oscar Gustafsson committed
    class TestIOCoordinates:
        """Tests for ``get_io_coordinates``.

        It returns ``(inputs, outputs)`` as ``[x, y]`` points. The x value is
        the port's latency offset; the y value spreads the ports vertically.
        """

        def test_simple_case(self):
            """One input and one output, both placed at y=0.5."""
            cmult = ConstantMultiplication(0.5)
            cmult.execution_time = 1
            cmult.set_latency(3)

            i_c, o_c = cmult.get_io_coordinates()
            assert i_c == [[0, 0.5]]
            assert o_c == [[3, 0.5]]

        def test_complicated_case(self):
            """Two inputs and two outputs, placed at y=0.25 and y=0.75."""
            bfly = Butterfly(
                latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10}
            )

            bfly.execution_time = 7

            i_c, o_c = bfly.get_io_coordinates()
            assert i_c == [[2, 0.25], [3, 0.75]]
            assert o_c == [[5, 0.25], [10, 0.75]]
    
    
    
    class TestSplit:
        """Splitting a composite operation into simpler operations."""

        def test_simple_case(self):
            """A butterfly splits into exactly one Addition and one Subtraction."""
            parts = Butterfly().split()
            assert len(parts) == 2

            additions = [op for op in parts if isinstance(op, Addition)]
            subtractions = [op for op in parts if isinstance(op, Subtraction)]
            assert len(additions) == 1
            assert len(subtractions) == 1
    
    
    
    class TestLatencyOffset:
        """Tests for ``set_latency_offsets`` and its key validation."""

        def test_set_latency_offsets(self):
            """Given ports are set; on a default Butterfly the others read None."""
            butterfly = Butterfly()
            butterfly.set_latency_offsets({"in0": 3, "out1": 5})

            expected = {"in0": 3, "in1": None, "out0": None, "out1": 5}
            assert butterfly.latency_offsets == expected

        def test_set_latency_offsets_error(self):
            """Malformed port keys raise ValueError naming the offending key."""
            butterfly = Butterfly()
            bad_cases = (
                (
                    {"ina": 3, "out1": 5},
                    "Incorrectly formatted index in string, expected 'in' + index,"
                    " got: 'ina'",
                ),
                (
                    {"in1": 3, "outb": 5},
                    "Incorrectly formatted index in string, expected 'out' +"
                    " index, got: 'outb'",
                ),
                (
                    {"foo": 3, "out2": 5},
                    "Incorrectly formatted string, expected 'in' + index or 'out'"
                    " + index, got: 'foo'",
                ),
            )

            for offsets, message in bad_cases:
                # Escape the message: it contains regex metacharacters such as '+'.
                with pytest.raises(ValueError, match=re.escape(message)):
                    butterfly.set_latency_offsets(offsets)