Skip to content
Snippets Groups Projects
test_resources.py 10.7 KiB
Newer Older
  • Learn to ignore specific revisions
  • Mikael Henriksson's avatar
    Mikael Henriksson committed
    import matplotlib.pyplot as plt
    import pytest
    
    
    from b_asic.core_operations import ConstantMultiplication
    
    from b_asic.process import PlainMemoryVariable
    
    from b_asic.research.interleaver import (
        generate_matrix_transposer,
        generate_random_interleaver,
    )
    
    from b_asic.resources import ProcessCollection, _ForwardBackwardTable
    
    Mikael Henriksson's avatar
    Mikael Henriksson committed
    
    
    class TestProcessCollectionPlainMemoryVariable:
        @pytest.mark.mpl_image_compare(style='mpl20')
        def test_draw_process_collection(self, simple_collection):
            fig, ax = plt.subplots()
    
    Oscar Gustafsson's avatar
    Oscar Gustafsson committed
            simple_collection.plot(ax=ax, show_markers=False)
    
    Mikael Henriksson's avatar
    Mikael Henriksson committed
            return fig
    
    
        @pytest.mark.mpl_image_compare(style='mpl20')
        def test_draw_matrix_transposer_4(self):
            fig, ax = plt.subplots()
    
            generate_matrix_transposer(4).plot(ax=ax)  # type: ignore
    
        def test_split_memory_variable(self, simple_collection: ProcessCollection):
    
            collection_split = simple_collection.split_on_ports(
    
                heuristic="graph_color", read_ports=1, write_ports=1, total_ports=2
            )
            assert len(collection_split) == 3
    
    
        @pytest.mark.mpl_image_compare(style='mpl20')
        def test_left_edge_cell_assignment(self, simple_collection: ProcessCollection):
            fig, ax = plt.subplots(1, 2)
    
            assignment = list(simple_collection._left_edge_assignment())
            for i, cell in enumerate(assignment):
                cell.plot(ax=ax[1], row=i)  # type: ignore
    
            simple_collection.plot(ax[0])  # type:ignore
    
        def test_cell_assignment_matrix_transposer(self):
            collection = generate_matrix_transposer(4, min_lifetime=5)
    
            assignment_left_edge = collection._left_edge_assignment()
            assignment_graph_color = collection.split_on_execution_time(
    
                coloring_strategy='saturation_largest_first'
            )
    
            assert len(assignment_left_edge) == 18
    
            assert len(assignment_graph_color) == 16
    
    
        def test_generate_memory_based_vhdl(self):
            for rows in [2, 3, 4, 5, 7]:
                collection = generate_matrix_transposer(rows, min_lifetime=0)
    
                assignment = collection.split_on_execution_time(heuristic="graph_color")
    
                collection.generate_memory_based_storage_vhdl(
    
                    filename=(
                        'b_asic/codegen/testbench/'
                        f'streaming_matrix_transposition_memory_{rows}x{rows}.vhdl'
                    ),
    
                    entity_name=f'streaming_matrix_transposition_memory_{rows}x{rows}',
                    assignment=assignment,
                    word_length=16,
                )
    
        def test_generate_register_based_vhdl(self):
            for rows in [2, 3, 4, 5, 7]:
                generate_matrix_transposer(
    
                ).generate_register_based_storage_vhdl(
    
                    filename=(
                        'b_asic/codegen/testbench/streaming_matrix_transposition_'
                        f'register_{rows}x{rows}.vhdl'
                    ),
    
                    entity_name=f'streaming_matrix_transposition_register_{rows}x{rows}',
                    word_length=16,
                )
    
        def test_rectangular_matrix_transposition(self):
            collection = generate_matrix_transposer(rows=4, cols=8, min_lifetime=2)
    
            assignment = collection.split_on_execution_time(heuristic="graph_color")
    
            collection.generate_memory_based_storage_vhdl(
    
                filename=(
                    'b_asic/codegen/testbench/streaming_matrix_transposition_memory_'
                    '4x8.vhdl'
                ),
    
    Oscar Gustafsson's avatar
    Oscar Gustafsson committed
                entity_name='streaming_matrix_transposition_memory_4x8',
    
                assignment=assignment,
                word_length=16,
            )
            collection.generate_register_based_storage_vhdl(
    
                filename=(
                    'b_asic/codegen/testbench/streaming_matrix_transposition_register_'
                    '4x8.vhdl'
                ),
    
    Oscar Gustafsson's avatar
    Oscar Gustafsson committed
                entity_name='streaming_matrix_transposition_register_4x8',
    
        def test_forward_backward_table_to_string(self):
            collection = ProcessCollection(
                collection={
                    PlainMemoryVariable(0, 0, {0: 5}, name="PC0"),
                    PlainMemoryVariable(1, 0, {0: 4}, name="PC1"),
                    PlainMemoryVariable(2, 0, {0: 3}, name="PC2"),
                    PlainMemoryVariable(3, 0, {0: 6}, name="PC3"),
                    PlainMemoryVariable(4, 0, {0: 6}, name="PC4"),
                    PlainMemoryVariable(5, 0, {0: 5}, name="PC5"),
                },
                schedule_time=7,
                cyclic=True,
            )
            t = _ForwardBackwardTable(collection)
            process_names = {match.group(0) for match in re.finditer(r'PC[0-9]+', str(t))}
            register_names = {match.group(0) for match in re.finditer(r'R[0-9]+', str(t))}
            assert len(process_names) == 6  # 6 process in the collection
            assert len(register_names) == 5  # 5 register required
            for i, process in enumerate(sorted(process_names)):
                assert process == f'PC{i}'
            for i, register in enumerate(sorted(register_names)):
                assert register == f'R{i}'
    
    
        def test_generate_random_interleaver(self):
    
            for _ in range(10):
                for size in range(5, 20, 5):
    
                    collection = generate_random_interleaver(size)
    
                    assert len(collection.split_on_ports(read_ports=1, write_ports=1)) == 1
    
                    if any(var.execution_time for var in collection.collection):
    
                        assert len(collection.split_on_ports(total_ports=1)) == 2
    
    Oscar Gustafsson's avatar
    Oscar Gustafsson committed
    
        def test_len_process_collection(self, simple_collection: ProcessCollection):
            assert len(simple_collection) == 7
    
    
        def test_get_by_type_name(self, secondorder_iir_schedule_with_execution_times):
            pc = secondorder_iir_schedule_with_execution_times.get_operations()
            pc_cmul = pc.get_by_type_name(ConstantMultiplication.type_name())
            assert len(pc_cmul) == 7
            assert all(
                isinstance(operand.operation, ConstantMultiplication)
                for operand in pc_cmul.collection
            )
    
    
        def test_show(self, simple_collection: ProcessCollection):
            simple_collection.show()
    
    
        def test_add_remove_process(self, simple_collection: ProcessCollection):
            new_proc = PlainMemoryVariable(1, 0, {0: 3})
            assert len(simple_collection) == 7
            assert new_proc not in simple_collection
    
            simple_collection.add_process(new_proc)
            assert len(simple_collection) == 8
            assert new_proc in simple_collection
    
            simple_collection.remove_process(new_proc)
            assert len(simple_collection) == 7
            assert new_proc not in simple_collection
    
    
        @pytest.mark.mpl_image_compare(style='mpl20')
        def test_max_min_lifetime_bar_plot(self):
            fig, ax = plt.subplots()
            collection = ProcessCollection(
                {
                    # Process starting exactly at scheudle start
                    PlainMemoryVariable(0, 0, {0: 0}, "S1"),
                    PlainMemoryVariable(0, 0, {0: 5}, "S2"),
                    # Process starting somewhere between schedule start and end
                    PlainMemoryVariable(2, 0, {0: 0}, "M1"),
                    PlainMemoryVariable(2, 0, {0: 5}, "M2"),
                    # Process starting at the schedule end
                    PlainMemoryVariable(5, 0, {0: 0}, "E1"),
                    PlainMemoryVariable(5, 0, {0: 5}, "E2"),
                },
                schedule_time=5,
            )
            collection.plot(ax)
            return fig
    
        def test_multiple_reads_exclusion_greaph(self):
            # Initial collection
            p0 = PlainMemoryVariable(0, 0, {0: 3}, 'P0')
            p1 = PlainMemoryVariable(1, 0, {0: 2}, 'P1')
            p2 = PlainMemoryVariable(2, 0, {0: 2}, 'P2')
            p3 = PlainMemoryVariable(3, 0, {0: 3}, 'P3')
            collection = ProcessCollection({p0, p1, p2, p3}, 5, cyclic=True)
            exclusion_graph = collection.create_exclusion_graph_from_ports(
                read_ports=1,
                write_ports=1,
                total_ports=1,
            )
            for p in [p0, p1, p2, p3]:
                assert p in exclusion_graph
            assert exclusion_graph.degree(p0) == 2
            assert exclusion_graph.degree(p1) == 2
            assert exclusion_graph.degree(p2) == 0
            assert exclusion_graph.degree(p3) == 2
    
            # Add multi-read process
            p4 = PlainMemoryVariable(0, 0, {0: 1, 1: 2, 2: 3, 3: 4}, 'P4')
            collection.add_process(p4)
            exclusion_graph = collection.create_exclusion_graph_from_ports(
                read_ports=1,
                write_ports=1,
                total_ports=1,
            )
            for p in [p0, p1, p2, p3, p4]:
                assert p in exclusion_graph
            assert exclusion_graph.degree(p0) == 3
            assert exclusion_graph.degree(p1) == 3
            assert exclusion_graph.degree(p2) == 1
            assert exclusion_graph.degree(p3) == 3
    
    
        def test_left_edge_maximum_lifetime(self):
            a = PlainMemoryVariable(2, 0, {0: 1}, "cmul1.0")
            b = PlainMemoryVariable(4, 0, {0: 7}, "cmul4.0")
            c = PlainMemoryVariable(5, 0, {0: 4}, "cmul5.0")
            collection = ProcessCollection([a, b, c], schedule_time=7, cyclic=True)
            for heuristic in ("graph_color", "left_edge"):
                assignment = collection.split_on_execution_time(heuristic)
                assert len(assignment) == 2
                a_idx = 0 if a in assignment[0] else 1
                assert b not in assignment[a_idx]
                assert c in assignment[a_idx]
    
        def test_split_on_execution_lifetime_assert(self):
            a = PlainMemoryVariable(3, 0, {0: 10}, "MV0")
            collection = ProcessCollection([a], schedule_time=9, cyclic=True)
            for heuristic in ("graph_color", "left_edge"):
                with pytest.raises(
                    ValueError,
                    match="MV0 has execution time greater than the schedule time",
                ):
                    collection.split_on_execution_time(heuristic)
    
    
        def test_split_on_length(self):
            # Test 1: Exclude a zero-time access time
            collection = ProcessCollection(
                collection=[PlainMemoryVariable(0, 1, {0: 1, 1: 2, 2: 3})],
                schedule_time=4,
            )
            short, long = collection.split_on_length(0)
            assert len(short) == 0 and len(long) == 1
            for split_time in [1, 2]:
                short, long = collection.split_on_length(split_time)
                assert len(short) == 1 and len(long) == 1
            short, long = collection.split_on_length(3)
            assert len(short) == 1 and len(long) == 0
    
            # Test 2: Include a zero-time access time
            collection = ProcessCollection(
                collection=[PlainMemoryVariable(0, 1, {0: 0, 1: 1, 2: 2, 3: 3})],
                schedule_time=4,
            )
            short, long = collection.split_on_length(0)
            assert len(short) == 1 and len(long) == 1
            for split_time in [1, 2]:
                short, long = collection.split_on_length(split_time)
                assert len(short) == 1 and len(long) == 1