Skip to content
Snippets Groups Projects
test_operation.py 2.85 KiB
Newer Older
  • Learn to ignore specific revisions
  • from b_asic.core_operations import Constant, Addition, ConstantAddition, Butterfly
    
    from b_asic.signal import Signal
    from b_asic.port import InputPort, OutputPort
    
    import pytest
    
    
    class TestTraverse:
        """Tests for iterating over operations with ``traverse()``."""

        def test_traverse_single_tree(self, operation):
            """A lone constant traverses to a list holding only itself."""
            lone_constant = Constant(None)
            visited = list(lone_constant.traverse())
            assert visited == [lone_constant]

        def test_traverse_tree(self, operation_tree):
            """The basic addition tree with two constants yields three operations."""
            visited = list(operation_tree.traverse())
            assert len(visited) == 3

        def test_traverse_large_tree(self, large_operation_tree):
            """The larger tree yields seven operations."""
            visited = list(large_operation_tree.traverse())
            assert len(visited) == 7

        def test_traverse_type(self, large_operation_tree):
            """The larger tree holds three Addition and four Constant instances."""
            visited = list(large_operation_tree.traverse())

            addition_count = sum(1 for op in visited if isinstance(op, Addition))
            assert addition_count == 3

            constant_count = sum(1 for op in visited if isinstance(op, Constant))
            assert constant_count == 4

        def test_traverse_loop(self, operation_tree):
            """Traverse after attaching one signal to both output 0 and input 0.

            The same signal is added to the first output port, then removed
            from and re-added to the first input port; the traversal is then
            expected to yield two operations.
            """
            feedback_signal = Signal()
            # The order of these port mutations is significant; keep it as is.
            operation_tree._output_ports[0].add_signal(feedback_signal)
            operation_tree._input_ports[0].remove_signal(feedback_signal)
            operation_tree._input_ports[0].add_signal(feedback_signal)

            visited = list(operation_tree.traverse())
            assert len(visited) == 2
    
    
    
    class TestEvaluateOutput:
        """Tests for ``evaluate_output()`` on Addition, ConstantAddition and Butterfly."""

        def test_evaluate_output_two_real_inputs(self):
            """Addition of the real inputs 1 and 2 evaluates to [3]."""
            adder = Addition()
            outputs = list(adder.evaluate_output(0, [1, 2]))

            assert outputs == [3]

        def test_evaluate_output_addition_two_complex_inputs(self):
            """Addition of 1+1j and 2 evaluates to [3+1j]."""
            adder = Addition()
            outputs = list(adder.evaluate_output(0, [1 + 1j, 2]))

            assert outputs == [3 + 1j]

        def test_evaluate_output_one_real_input(self):
            """ConstantAddition(5) with the real input 1 evaluates to [6]."""
            const_adder = ConstantAddition(5)
            outputs = list(const_adder.evaluate_output(0, [1]))

            assert outputs == [6]

        def test_evaluate_output_one_complex_input(self):
            """ConstantAddition(5) with the complex input 1+1j evaluates to [6+1j]."""
            const_adder = ConstantAddition(5)
            outputs = list(const_adder.evaluate_output(0, [1 + 1j]))

            assert outputs == [6 + 1j]

        def test_evaluate_output_two_real_inputs_two_outputs(self):
            """Butterfly of 6 and 9 evaluates to [15, -3] for output index 0 and 1."""
            butterfly = Butterfly()

            # Both output indices are queried on the same instance, in order.
            for output_index in (0, 1):
                outputs = list(butterfly.evaluate_output(output_index, [6, 9]))
                assert outputs == [15, -3]

        def test_evaluate_output_two_complex_inputs_two_outputs(self):
            """Butterfly of 3+2j and 4+2j evaluates to [7+4j, -1] for index 0 and 1."""
            butterfly = Butterfly()

            # Both output indices are queried on the same instance, in order.
            for output_index in (0, 1):
                outputs = list(
                    butterfly.evaluate_output(output_index, [3 + 2j, 4 + 2j]))
                assert outputs == [7 + 4j, -1]