diff --git a/b_asic/resources.py b/b_asic/resources.py index c50007a75123ae533963141e0f7cce4e8252dca2..f6a0a1f0684a95cdd8eb6ab380035cefde93b077 100644 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -187,10 +187,12 @@ class ProcessCollection: # Generate the life-time chart for i, process in enumerate(_sorted_nicely(self._collection)): bar_start = process.start_time % self._schedule_time + bar_end = process.start_time + process.execution_time bar_end = ( - process.start_time + process.execution_time - ) % self._schedule_time - bar_end = self._schedule_time if bar_end == 0 else bar_end + bar_end + if bar_end == self._schedule_time + else bar_end % self._schedule_time + ) if show_markers: _ax.scatter( x=bar_start, @@ -240,16 +242,84 @@ class ProcessCollection: _ax.set_ylim(0.25, len(self._collection) + 0.75) return _ax - def create_exclusion_graph_from_overlap( - self, add_name: bool = True + def create_exclusion_graph_from_ports( + self, + read_ports: Optional[int] = None, + write_ports: Optional[int] = None, + total_ports: Optional[int] = None, ) -> nx.Graph: """ - Generate exclusion graph based on processes overlapping in time + Create an exclusion graph from a ProcessCollection based on a number of read/write ports Parameters ---------- - add_name : bool, default: True - Add name of all processes as a node attribute in the exclusion graph. + read_ports : int + The number of read ports used when splitting process collection based on memory variable access. + write_ports : int + The number of write ports used when splitting process collection based on memory variable access. + total_ports : int + The total number of ports used when splitting process collection based on memory variable access. + + Returns + ------- + nx.Graph + + """ + if total_ports is None: + if read_ports is None or write_ports is None: + raise ValueError( + "If total_ports is unset, both read_ports and write_ports" + " must be provided." + ) + else: + total_ports = read_ports + write_ports + else: + read_ports = total_ports if read_ports is None else read_ports + write_ports = total_ports if write_ports is None else write_ports + + # Guard for proper read/write port settings + if read_ports != 1 or write_ports != 1: + raise ValueError( + "Splitting with read and write ports not equal to one with the" + " graph coloring heuristic does not make sense." + ) + if total_ports not in (1, 2): + raise ValueError( + "Total ports should be either 1 (non-concurrent reads/writes)" + " or 2 (concurrent read/writes) for graph coloring heuristic." + ) + + # Create new exclusion graph. Nodes are Processes + exclusion_graph = nx.Graph() + exclusion_graph.add_nodes_from(self._collection) + for node1 in exclusion_graph: + for node2 in exclusion_graph: + if node1 == node2: + continue + else: + node1_stop_time = node1.start_time + node1.execution_time + node2_stop_time = node2.start_time + node2.execution_time + if total_ports == 1: + # Single-port assignment + if node1.start_time == node2.start_time: + exclusion_graph.add_edge(node1, node2) + elif node1_stop_time == node2_stop_time: + exclusion_graph.add_edge(node1, node2) + elif node1.start_time == node2_stop_time: + exclusion_graph.add_edge(node1, node2) + elif node1_stop_time == node2.start_time: + exclusion_graph.add_edge(node1, node2) + else: + # Dual-port assignment + if node1.start_time == node2.start_time: + exclusion_graph.add_edge(node1, node2) + elif node1_stop_time == node2_stop_time: + exclusion_graph.add_edge(node1, node2) + return exclusion_graph + + def create_exclusion_graph_from_execution_time(self) -> nx.Graph: + """ + Generate exclusion graph based on processes overlapping in time Returns ------- @@ -279,7 +349,47 @@ class ProcessCollection: exclusion_graph.add_edge(process1, process2) return exclusion_graph - def split( + def split_execution_time( + self, heuristic: str = "graph_color", coloring_strategy: str = "DSATUR" + ) -> Set["ProcessCollection"]: + """ + Split a ProcessCollection based on overlapping execution time. + + Parameters + ---------- + heuristic : str, default: 'graph_color' + The heuristic used when splitting based on execution times. + One of: 'graph_color', 'left_edge'. + coloring_strategy: str, default: 'DSATUR' + Node ordering strategy passed to nx.coloring.greedy_color() if the heuristic is set to 'graph_color'. This + parameter is only considered if heuristic is set to graph_color. + One of + * `'largest_first'` + * `'random_sequential'` + * `'smallest_last'` + * `'independent_set'` + * `'connected_sequential_bfs'` + * `'connected_sequential_dfs'` + * `'connected_sequential'` (alias for the previous strategy) + * `'saturation_largest_first'` + * `'DSATUR'` (alias for the saturation_largest_first strategy) + + Returns + ------- + A set of new ProcessCollection objects with the process splitting. + """ + if heuristic == "graph_color": + exclusion_graph = self.create_exclusion_graph_from_execution_time() + coloring = nx.coloring.greedy_color( + exclusion_graph, strategy=coloring_strategy + ) + return self._split_from_graph_coloring(coloring) + elif heuristic == "left_edge": + raise NotImplementedError() + else: + raise ValueError(f"Invalid heuristic '{heuristic}'") + + def split_ports( self, heuristic: str = "graph_color", read_ports: Optional[int] = None, @@ -309,77 +419,79 @@ class ProcessCollection: """ if total_ports is None: if read_ports is None or write_ports is None: - raise ValueError("inteligent quote") + raise ValueError( + "If total_ports is unset, both read_ports and write_ports" + " must be provided." + ) else: total_ports = read_ports + write_ports else: read_ports = total_ports if read_ports is None else read_ports write_ports = total_ports if write_ports is None else write_ports - if heuristic == "graph_color": - return self._split_graph_color( + return self._split_ports_graph_color( read_ports, write_ports, total_ports ) else: - raise ValueError("Invalid heuristic provided") + raise ValueError("Invalid heuristic provided.") - def _split_graph_color( - self, read_ports: int, write_ports: int, total_ports: int + def _split_ports_graph_color( + self, + read_ports: int, + write_ports: int, + total_ports: int, + coloring_strategy: str = "DSATUR", ) -> Set["ProcessCollection"]: """ Parameters ---------- - read_ports : int, optional + read_ports : int The number of read ports used when splitting process collection based on memory variable access. - write_ports : int, optional + write_ports : int The number of write ports used when splitting process collection based on memory variable access. - total_ports : int, optional + total_ports : int The total number of ports used when splitting process collection based on memory variable access. + coloring_strategy: str, default: 'DSATUR' + Node ordering strategy passed to nx.coloring.greedy_color() + One of + * `'largest_first'` + * `'random_sequential'` + * `'smallest_last'` + * `'independent_set'` + * `'connected_sequential_bfs'` + * `'connected_sequential_dfs'` + * `'connected_sequential'` (alias for the previous strategy) + * `'saturation_largest_first'` + * `'DSATUR'` (alias for the saturation_largest_first strategy) """ - if read_ports != 1 or write_ports != 1: - raise ValueError( - "Splitting with read and write ports not equal to one with the" - " graph coloring heuristic does not make sense." - ) - if total_ports not in (1, 2): - raise ValueError( - "Total ports should be either 1 (non-concurrent reads/writes)" - " or 2 (concurrent read/writes) for graph coloring heuristic." - ) - # Create new exclusion graph. Nodes are Processes - exclusion_graph = nx.Graph() - exclusion_graph.add_nodes_from(self._collection) + exclusion_graph = self.create_exclusion_graph_from_ports( + read_ports, write_ports, total_ports + ) - # Add exclusions (arcs) between processes in the exclusion graph - for node1 in exclusion_graph: - for node2 in exclusion_graph: - if node1 == node2: - continue - else: - node1_stop_time = node1.start_time + node1.execution_time - node2_stop_time = node2.start_time + node2.execution_time - if total_ports == 1: - # Single-port assignment - if node1.start_time == node2.start_time: - exclusion_graph.add_edge(node1, node2) - elif node1_stop_time == node2_stop_time: - exclusion_graph.add_edge(node1, node2) - elif node1.start_time == node2_stop_time: - exclusion_graph.add_edge(node1, node2) - elif node1_stop_time == node2.start_time: - exclusion_graph.add_edge(node1, node2) - else: - # Dual-port assignment - if node1.start_time == node2.start_time: - exclusion_graph.add_edge(node1, node2) - elif node1_stop_time == node2_stop_time: - exclusion_graph.add_edge(node1, node2) + # Perform assignment from coloring and return result + coloring = nx.coloring.greedy_color( + exclusion_graph, strategy=coloring_strategy + ) + return self._split_from_graph_coloring(coloring) + + def _split_from_graph_coloring( + self, + coloring: Dict[Process, int], + ) -> Set["ProcessCollection"]: + """ + Split :class:`Process` objects into a set of :class:`ProcessesCollection` objects based on a provided graph coloring. + Resulting :class:`ProcessCollection` will have the same schedule time and cyclic propoery as self. + + Parameters + ---------- + coloring : Dict[Process, int] + Process->int (color) mappings - # Perform assignment - coloring = nx.coloring.greedy_color(exclusion_graph) - draw_exclusion_graph_coloring(exclusion_graph, coloring) - # process_collection_list = [ProcessCollection()]*(max(coloring.values()) + 1) + Returns + ------- + A set of new ProcessCollections. + """ process_collection_set_list = [ set() for _ in range(max(coloring.values()) + 1) ] diff --git a/test/fixtures/interleaver-two-port-issue175.p b/test/fixtures/interleaver-two-port-issue175.p new file mode 100644 index 0000000000000000000000000000000000000000..f62ee38cbca0f1ec5ff14ddc073b58976cb69ad1 Binary files /dev/null and b/test/fixtures/interleaver-two-port-issue175.p differ diff --git a/test/test_resources.py b/test/test_resources.py index 38dfc2457010954fada118827aac23682638e047..581474af6b3f8ed8e9e5661a7219fad2d7bc2a9a 100644 --- a/test/test_resources.py +++ b/test/test_resources.py @@ -1,3 +1,5 @@ +import pickle + import matplotlib.pyplot as plt import networkx as nx import pytest @@ -6,7 +8,7 @@ from b_asic.research.interleaver import ( generate_matrix_transposer, generate_random_interleaver, ) -from b_asic.resources import draw_exclusion_graph_coloring +from b_asic.resources import ProcessCollection, draw_exclusion_graph_coloring class TestProcessCollectionPlainMemoryVariable: @@ -16,40 +18,40 @@ class TestProcessCollectionPlainMemoryVariable: simple_collection.draw_lifetime_chart(ax=ax, show_markers=False) return fig - def test_draw_proces_collection(self, simple_collection): - _, ax = plt.subplots(1, 2) - simple_collection.draw_lifetime_chart(ax=ax[0]) - exclusion_graph = ( - simple_collection.create_exclusion_graph_from_overlap() - ) - color_dict = nx.coloring.greedy_color(exclusion_graph) - draw_exclusion_graph_coloring(exclusion_graph, color_dict, ax=ax[1]) - - def test_split_memory_variable(self, simple_collection): - collection_split = simple_collection.split( - read_ports=1, write_ports=1, total_ports=2 - ) - assert len(collection_split) == 3 - @pytest.mark.mpl_image_compare(style='mpl20') def test_draw_matrix_transposer_4(self): fig, ax = plt.subplots() generate_matrix_transposer(4).draw_lifetime_chart(ax=ax) return fig + def test_split_memory_variable(self, simple_collection: ProcessCollection): + collection_split = simple_collection.split_ports( + heuristic="graph_color", read_ports=1, write_ports=1, total_ports=2 + ) + assert len(collection_split) == 3 + + # Issue: #175 + def test_interleaver_issue175(self): + with open('test/fixtures/interleaver-two-port-issue175.p', 'rb') as f: + interleaver_collection: ProcessCollection = pickle.load(f) + assert len(interleaver_collection.split_ports(total_ports=1)) == 2 + def test_generate_random_interleaver(self): - return for _ in range(10): for size in range(5, 20, 5): assert ( len( - generate_random_interleaver(size).split( + generate_random_interleaver(size).split_ports( read_ports=1, write_ports=1 ) ) == 1 ) assert ( - len(generate_random_interleaver(size).split(total_ports=1)) + len( + generate_random_interleaver(size).split_ports( + total_ports=1 + ) + ) == 2 )