Skip to content
Snippets Groups Projects
test_operation.py 1.8 KiB
Newer Older
  • Learn to ignore specific revisions
  • import pytest
    
    
    from b_asic import Constant, Addition, MAD, Butterfly, SquareRoot
    
    class TestTraverse:
        """Tests for ``traverse()`` on single operations and operation graphs."""

        def test_traverse_single_tree(self):
            """Traverse a tree consisting of one operation."""
            # No fixture needed: the single-node tree is built locally.
            constant = Constant(None)
            assert list(constant.traverse()) == [constant]

        def test_traverse_tree(self, operation_tree):
            """Traverse a basic addition tree with two constants."""
            assert len(list(operation_tree.traverse())) == 5

        def test_traverse_large_tree(self, large_operation_tree):
            """Traverse a larger tree."""
            assert len(list(large_operation_tree.traverse())) == 13

        def test_traverse_type(self, large_operation_tree):
            """Count the traversed operations of each type in the larger tree."""
            result = list(large_operation_tree.traverse())
            # Booleans sum as 0/1, so each sum is the number of matching nodes.
            assert sum(isinstance(op, Addition) for op in result) == 3
            assert sum(isinstance(op, Constant) for op in result) == 4

        def test_traverse_loop(self, operation_graph_with_cycle):
            """Traverse a graph containing a cycle.

            ``list()`` consumes the whole traversal, so this also checks that
            traversing a cyclic graph finishes, yielding 8 items.
            """
            assert len(list(operation_graph_with_cycle.traverse())) == 8
    
    class TestToSfg:
        """Tests that ``to_sfg()`` yields an SFG equivalent to its operation.

        Distinct input values are used for the output comparisons: with
        all-equal inputs many wiring mistakes in the generated SFG (e.g. an
        input routed to the wrong port, or a subtraction with swapped operands)
        produce the same result and would go unnoticed.
        """

        def test_convert_mad_to_sfg(self):
            """Convert a MAD to an SFG; compare outputs and operation count."""
            mad1 = MAD()
            mad1_sfg = mad1.to_sfg()

            assert mad1.evaluate(2, 3, 5) == mad1_sfg.evaluate(2, 3, 5)
            assert len(mad1_sfg.operations) == 6

        def test_butterfly_to_sfg(self):
            """Convert a Butterfly to an SFG; compare both outputs and op count."""
            but1 = Butterfly()
            but1_sfg = but1.to_sfg()

            # Evaluate each side once and compare output by output.
            expected = but1.evaluate(3, 2)
            actual = but1_sfg.evaluate(3, 2)
            assert expected[0] == actual[0]
            assert expected[1] == actual[1]
            assert len(but1_sfg.operations) == 8

        def test_add_to_sfg(self):
            """Convert an Addition to an SFG; compare output and operation count."""
            add1 = Addition()
            add1_sfg = add1.to_sfg()

            assert add1.evaluate(2, 3) == add1_sfg.evaluate(2, 3)
            assert len(add1_sfg.operations) == 4

        def test_sqrt_to_sfg(self):
            """Convert a SquareRoot to an SFG; compare output and operation count."""
            sqrt1 = SquareRoot()
            sqrt1_sfg = sqrt1.to_sfg()

            assert sqrt1.evaluate(4) == sqrt1_sfg.evaluate(4)
            assert len(sqrt1_sfg.operations) == 3