"""
B-ASIC test suite for the AbstractOperation class.
"""
import re

import pytest

from b_asic import (
    MAD,
    Addition,
    Butterfly,
    Constant,
    ConstantMultiplication,
    Division,
    Multiplication,
    Reciprocal,
    SquareRoot,
    Subtraction,
)


class TestOperationOverloading:
    """Arithmetic operators applied to operations and plain numbers."""

    def test_addition_overload(self):
        """``+`` with an operation, and with a number on either side."""
        lhs = Addition(None, None, "add1")
        rhs = Addition(None, None, "add2")

        op_plus_op = lhs + rhs
        assert isinstance(op_plus_op, Addition)
        assert op_plus_op.input(0).signals == lhs.output(0).signals
        assert op_plus_op.input(1).signals == rhs.output(0).signals

        op_plus_num = op_plus_op + 5
        assert isinstance(op_plus_num, Addition)
        assert op_plus_num.input(0).signals == op_plus_op.output(0).signals
        assert op_plus_num.input(1).signals[0].source.operation.value == 5

        num_plus_op = 5 + op_plus_num
        assert isinstance(num_plus_op, Addition)
        assert num_plus_op.input(0).signals[0].source.operation.value == 5
        assert num_plus_op.input(1).signals == op_plus_num.output(0).signals

    def test_subtraction_overload(self):
        """``-`` with an operation, and with a number on either side."""
        lhs = Addition(None, None, "add1")
        rhs = Addition(None, None, "add2")

        op_minus_op = lhs - rhs
        assert isinstance(op_minus_op, Subtraction)
        assert op_minus_op.input(0).signals == lhs.output(0).signals
        assert op_minus_op.input(1).signals == rhs.output(0).signals

        op_minus_num = op_minus_op - 5
        assert isinstance(op_minus_num, Subtraction)
        assert op_minus_num.input(0).signals == op_minus_op.output(0).signals
        assert op_minus_num.input(1).signals[0].source.operation.value == 5

        num_minus_op = 5 - op_minus_num
        assert isinstance(num_minus_op, Subtraction)
        assert num_minus_op.input(0).signals[0].source.operation.value == 5
        assert num_minus_op.input(1).signals == op_minus_num.output(0).signals

    def test_multiplication_overload(self):
        """``*`` with an operation, and with a number on either side."""
        lhs = Addition(None, None, "add1")
        rhs = Addition(None, None, "add2")

        op_times_op = lhs * rhs
        assert isinstance(op_times_op, Multiplication)
        assert op_times_op.input(0).signals == lhs.output(0).signals
        assert op_times_op.input(1).signals == rhs.output(0).signals

        # A numeric factor is expected to yield a ConstantMultiplication
        # that carries the number in ``value``.
        op_times_num = op_times_op * 5
        assert isinstance(op_times_num, ConstantMultiplication)
        assert op_times_num.input(0).signals == op_times_op.output(0).signals
        assert op_times_num.value == 5

        num_times_op = 5 * op_times_num
        assert isinstance(num_times_op, ConstantMultiplication)
        assert num_times_op.input(0).signals == op_times_num.output(0).signals
        assert num_times_op.value == 5

    def test_division_overload(self):
        """``/`` with an operation, and with a number on either side."""
        lhs = Addition(None, None, "add1")
        rhs = Addition(None, None, "add2")

        op_over_op = lhs / rhs
        assert isinstance(op_over_op, Division)
        assert op_over_op.input(0).signals == lhs.output(0).signals
        assert op_over_op.input(1).signals == rhs.output(0).signals

        op_over_num = op_over_op / 5
        assert isinstance(op_over_num, Division)
        assert op_over_num.input(0).signals == op_over_op.output(0).signals
        assert op_over_num.input(1).signals[0].source.operation.value == 5

        num_over_op = 5 / op_over_num
        assert isinstance(num_over_op, Division)
        assert num_over_op.input(0).signals[0].source.operation.value == 5
        assert num_over_op.input(1).signals == op_over_num.output(0).signals

        # A numerator of exactly 1 is expected to give a Reciprocal.
        one_over_op = 1 / num_over_op
        assert isinstance(one_over_op, Reciprocal)
        assert one_over_op.input(0).signals == num_over_op.output(0).signals


class TestTraverse:
    """Tests for ``traverse()`` on operation graphs of various sizes."""

    def test_traverse_single_tree(self, operation):
        """Traverse a tree consisting of one operation."""
        lone_constant = Constant(None)
        visited = list(lone_constant.traverse())
        assert visited == [lone_constant]

    def test_traverse_tree(self, operation_tree):
        """Traverse a basic addition tree with two constants."""
        visited = list(operation_tree.traverse())
        assert len(visited) == 5

    def test_traverse_large_tree(self, large_operation_tree):
        """Traverse a larger tree."""
        visited = list(large_operation_tree.traverse())
        assert len(visited) == 13

    def test_traverse_type(self, large_operation_tree):
        """Count the operations of each type found in the larger tree."""
        visited = list(large_operation_tree.traverse())
        assert sum(isinstance(op, Addition) for op in visited) == 3
        assert sum(isinstance(op, Constant) for op in visited) == 4

    def test_traverse_loop(self, operation_graph_with_cycle):
        """Traverse a graph that contains a cycle."""
        visited = list(operation_graph_with_cycle.traverse())
        assert len(visited) == 8


class TestToSfg:
    """Tests for converting single operations with ``to_sfg()``."""

    def test_convert_mad_to_sfg(self):
        """A MAD and its SFG evaluate equally; the SFG has 6 operations."""
        mad = MAD()
        sfg = mad.to_sfg()

        assert mad.evaluate(1, 1, 1) == sfg.evaluate(1, 1, 1)
        assert len(sfg.operations) == 6

    def test_butterfly_to_sfg(self):
        """Both butterfly outputs match its SFG; the SFG has 8 operations."""
        butterfly = Butterfly()
        sfg = butterfly.to_sfg()

        for output_index in (0, 1):
            direct = butterfly.evaluate(1, 1)[output_index]
            assert direct == sfg.evaluate(1, 1)[output_index]
        assert len(sfg.operations) == 8

    def test_add_to_sfg(self):
        """The SFG of an addition has 4 operations."""
        sfg = Addition().to_sfg()
        assert len(sfg.operations) == 4

    def test_sqrt_to_sfg(self):
        """The SFG of a square root has 3 operations."""
        sfg = SquareRoot().to_sfg()
        assert len(sfg.operations) == 3


class TestLatency:
    """Tests for latency and latency offsets of a Butterfly."""

    def test_latency_constructor(self):
        """Only ``latency`` is passed to the constructor."""
        butterfly = Butterfly(latency=5)

        assert butterfly.latency == 5
        expected_offsets = {"in0": 0, "in1": 0, "out0": 5, "out1": 5}
        assert butterfly.latency_offsets == expected_offsets

    def test_latency_offsets_constructor(self):
        """Only ``latency_offsets`` is passed to the constructor."""
        butterfly = Butterfly(
            latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10}
        )

        assert butterfly.latency == 8
        expected_offsets = {"in0": 2, "in1": 3, "out0": 5, "out1": 10}
        assert butterfly.latency_offsets == expected_offsets

    def test_latency_and_latency_offsets_constructor(self):
        """Both ``latency`` and partial ``latency_offsets`` are passed."""
        butterfly = Butterfly(latency=5, latency_offsets={"in1": 2, "out0": 9})

        assert butterfly.latency == 9
        expected_offsets = {"in0": 0, "in1": 2, "out0": 9, "out1": 5}
        assert butterfly.latency_offsets == expected_offsets

    def test_set_latency(self):
        """``set_latency`` is called after default construction."""
        butterfly = Butterfly()
        butterfly.set_latency(9)

        assert butterfly.latency == 9
        expected_offsets = {"in0": 0, "in1": 0, "out0": 9, "out1": 9}
        assert butterfly.latency_offsets == expected_offsets


class TestExecutionTime:
    """Tests for the ``execution_time`` attribute."""

    def test_execution_time_constructor(self):
        """Placeholder: this test currently performs no checks."""

    def test_set_execution_time(self):
        """An assigned execution time is read back unchanged."""
        butterfly = Butterfly()
        butterfly.execution_time = 3
        assert butterfly.execution_time == 3

    def test_set_execution_time_negative(self):
        """Assigning a negative execution time raises ValueError."""
        butterfly = Butterfly()
        expected_message = "Execution time cannot be negative"
        with pytest.raises(ValueError, match=expected_message):
            butterfly.execution_time = -1


class TestCopyOperation:
    """Tests for ``copy_component()``."""

    def test_copy_butterfly_latency_offsets(self):
        """The copy reports the same latency offsets as were constructed."""
        original = Butterfly(
            latency_offsets={"in0": 4, "in1": 2, "out0": 10, "out1": 9}
        )
        duplicate = original.copy_component()

        expected_offsets = {"in0": 4, "in1": 2, "out0": 10, "out1": 9}
        assert duplicate.latency_offsets == expected_offsets

    def test_copy_execution_time(self):
        """The copy reports the execution time set on the original."""
        original = Addition()
        original.execution_time = 2
        duplicate = original.copy_component()
        assert duplicate.execution_time == 2


class TestPlotCoordinates:
    """Tests for ``get_plot_coordinates()``."""

    def test_simple_case(self):
        """Constant multiplication, execution time 1 and latency 3."""
        cmult = ConstantMultiplication(0.5)
        cmult.execution_time = 1
        cmult.set_latency(3)

        latency_outline, execution_outline = cmult.get_plot_coordinates()
        assert latency_outline == [[0, 0], [0, 1], [3, 1], [3, 0], [0, 0]]
        assert execution_outline == [[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]

    def test_complicated_case(self):
        """Butterfly with a different latency offset on every port."""
        butterfly = Butterfly(
            latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10}
        )
        butterfly.execution_time = 7

        latency_outline, execution_outline = butterfly.get_plot_coordinates()
        expected_latency_outline = [
            [2, 0],
            [2, 0.5],
            [3, 0.5],
            [3, 1],
            [10, 1],
            [10, 0.5],
            [5, 0.5],
            [5, 0],
            [2, 0],
        ]
        assert latency_outline == expected_latency_outline
        assert execution_outline == [[0, 0], [0, 1], [7, 1], [7, 0], [0, 0]]


class TestIOCoordinates:
    """Tests for ``get_io_coordinates()``."""

    def test_simple_case(self):
        """Constant multiplication, execution time 1 and latency 3."""
        cmult = ConstantMultiplication(0.5)
        cmult.execution_time = 1
        cmult.set_latency(3)

        input_coords, output_coords = cmult.get_io_coordinates()
        assert input_coords == [[0, 0.5]]
        assert output_coords == [[3, 0.5]]

    def test_complicated_case(self):
        """Butterfly with a different latency offset on every port."""
        butterfly = Butterfly(
            latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10}
        )
        butterfly.execution_time = 7

        input_coords, output_coords = butterfly.get_io_coordinates()
        assert input_coords == [[2, 0.25], [3, 0.75]]
        assert output_coords == [[5, 0.25], [10, 0.75]]


class TestSplit:
    """Tests for ``split()``."""

    def test_simple_case(self):
        """A Butterfly splits into one Addition and one Subtraction."""
        parts = Butterfly().split()

        assert len(parts) == 2
        assert len([op for op in parts if isinstance(op, Addition)]) == 1
        assert len([op for op in parts if isinstance(op, Subtraction)]) == 1


class TestLatencyOffset:
    """Tests for ``set_latency_offsets()``."""

    def test_set_latency_offsets(self):
        """Ports not named in the call are reported as ``None``."""
        butterfly = Butterfly()
        butterfly.set_latency_offsets({"in0": 3, "out1": 5})

        expected_offsets = {"in0": 3, "in1": None, "out0": None, "out1": 5}
        assert butterfly.latency_offsets == expected_offsets

    def test_set_latency_offsets_error(self):
        """Malformed port keys raise ValueError with a specific message."""
        butterfly = Butterfly()

        # (expected message, offending argument); the same Butterfly
        # instance is reused for all three calls, in this order.
        failing_cases = [
            (
                "Incorrectly formatted index in string, expected 'in' + index,"
                " got: 'ina'",
                {"ina": 3, "out1": 5},
            ),
            (
                "Incorrectly formatted index in string, expected 'out' +"
                " index, got: 'outb'",
                {"in1": 3, "outb": 5},
            ),
            (
                "Incorrectly formatted string, expected 'in' + index or 'out'"
                " + index, got: 'foo'",
                {"foo": 3, "out2": 5},
            ),
        ]
        for expected_message, bad_offsets in failing_cases:
            with pytest.raises(ValueError, match=re.escape(expected_message)):
                butterfly.set_latency_offsets(bad_offsets)