From bf6078f89a7406d10cbdde7617defad02f4ce051 Mon Sep 17 00:00:00 2001 From: Oscar Gustafsson <oscar.gustafsson@gmail.com> Date: Fri, 17 Feb 2023 22:21:04 +0100 Subject: [PATCH] Improve robustness for resource tests --- b_asic/resources.py | 41 ++++++++++++++++++++--------------------- test/test_resources.py | 20 ++++---------------- 2 files changed, 24 insertions(+), 37 deletions(-) diff --git a/b_asic/resources.py b/b_asic/resources.py index f6a0a1f0..362e09dd 100644 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -25,9 +25,7 @@ _T = TypeVar('_T') def _sorted_nicely(to_be_sorted: Iterable[_T]) -> List[_T]: """Sort the given iterable in the way that humans expect.""" convert = lambda text: int(text) if text.isdigit() else text - alphanum_key = lambda key: [ - convert(c) for c in re.split('([0-9]+)', str(key)) - ] + alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', str(key))] return sorted(to_be_sorted, key=alphanum_key) @@ -35,9 +33,7 @@ def draw_exclusion_graph_coloring( exclusion_graph: nx.Graph, color_dict: Dict[Process, int], ax: Optional[Axes] = None, - color_list: Optional[ - Union[List[str], List[Tuple[float, float, float]]] - ] = None, + color_list: Optional[Union[List[str], List[Tuple[float, float, float]]]] = None, ): """ Use matplotlib.pyplot and networkx to draw a colored exclusion graph from the memory assignment @@ -119,6 +115,13 @@ class ProcessCollection: self._schedule_time = schedule_time self._cyclic = cyclic + @property + def collection(self): + return self._collection + + def __len__(self): + return len(self.__collection__) + def add_process(self, process: Process): """ Add a new process to this process collection. @@ -174,9 +177,7 @@ class ProcessCollection: # Lifetime chart left and right padding PAD_L, PAD_R = 0.05, 0.05 - max_execution_time = max( - process.execution_time for process in self._collection - ) + max_execution_time = max(process.execution_time for process in self._collection) if max_execution_time > self._schedule_time: # Schedule time needs to be greater than or equal to the maximum process lifetime raise KeyError( @@ -429,9 +430,7 @@ class ProcessCollection: read_ports = total_ports if read_ports is None else read_ports write_ports = total_ports if write_ports is None else write_ports if heuristic == "graph_color": - return self._split_ports_graph_color( - read_ports, write_ports, total_ports - ) + return self._split_ports_graph_color(read_ports, write_ports, total_ports) else: raise ValueError("Invalid heuristic provided.") @@ -470,9 +469,7 @@ class ProcessCollection: ) # Perform assignment from coloring and return result - coloring = nx.coloring.greedy_color( - exclusion_graph, strategy=coloring_strategy - ) + coloring = nx.coloring.greedy_color(exclusion_graph, strategy=coloring_strategy) return self._split_from_graph_coloring(coloring) def _split_from_graph_coloring( @@ -492,15 +489,11 @@ class ProcessCollection: ------- A set of new ProcessCollections. """ - process_collection_set_list = [ - set() for _ in range(max(coloring.values()) + 1) - ] + process_collection_set_list = [set() for _ in range(max(coloring.values()) + 1)] for process, color in coloring.items(): process_collection_set_list[color].add(process) return { - ProcessCollection( - process_collection_set, self._schedule_time, self._cyclic - ) + ProcessCollection(process_collection_set, self._schedule_time, self._cyclic) for process_collection_set in process_collection_set_list } @@ -515,3 +508,9 @@ class ProcessCollection: fig.savefig(f, format="svg") return f.getvalue() + + def __repr__(self): + return ( + f"ProcessCollection({self._collection}, {self._schedule_time}," + f" {self._cyclic})" + ) diff --git a/test/test_resources.py b/test/test_resources.py index 581474af..18a84141 100644 --- a/test/test_resources.py +++ b/test/test_resources.py @@ -39,19 +39,7 @@ class TestProcessCollectionPlainMemoryVariable: def test_generate_random_interleaver(self): for _ in range(10): for size in range(5, 20, 5): - assert ( - len( - generate_random_interleaver(size).split_ports( - read_ports=1, write_ports=1 - ) - ) - == 1 - ) - assert ( - len( - generate_random_interleaver(size).split_ports( - total_ports=1 - ) - ) - == 2 - ) + collection = generate_random_interleaver(size) + assert len(collection.split_ports(read_ports=1, write_ports=1)) == 1 + if any(var.execution_time for var in collection.collection): + assert len(collection.split_ports(total_ports=1)) == 2 -- GitLab