From 8a028e82f6b6379b2e02ace621ad11b7f74869f6 Mon Sep 17 00:00:00 2001
From: Mikael Henriksson <mike.zx@hotmail.com>
Date: Wed, 30 Aug 2023 17:35:27 +0200
Subject: [PATCH] resources.py: add split_ports_sequentially() and left-edge
 based split_on_ports()

---
 b_asic/resources.py       | 102 ++++++++++++++++++++++++++++++++++++--
 test/test_architecture.py |  11 ++--
 test/test_resources.py    |  46 ++++++++++++++---
 3 files changed, 143 insertions(+), 16 deletions(-)

diff --git a/b_asic/resources.py b/b_asic/resources.py
index 25cc9f07..fa7c351f 100644
--- a/b_asic/resources.py
+++ b/b_asic/resources.py
@@ -1,6 +1,6 @@
 import io
 import re
-from collections import Counter
+from collections import Counter, defaultdict
 from functools import reduce
 from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union
 
@@ -886,7 +886,7 @@ class ProcessCollection:
 
     def split_on_ports(
         self,
-        heuristic: str = "graph_color",
+        heuristic: str = "left_edge",
         read_ports: Optional[int] = None,
         write_ports: Optional[int] = None,
         total_ports: Optional[int] = None,
@@ -903,7 +903,7 @@ class ProcessCollection:
             Valid options are:
 
             * "graph_color"
-            * "..."
+            * "left_edge"
 
         read_ports : int, optional
             The number of read ports used when splitting process collection based on
@@ -926,9 +926,105 @@ class ProcessCollection:
         )
         if heuristic == "graph_color":
             return self._split_ports_graph_color(read_ports, write_ports, total_ports)
+        elif heuristic == "left_edge":
+            return self.split_ports_sequentially(
+                read_ports,
+                write_ports,
+                total_ports,
+                sequence=sorted(self),
+            )
         else:
             raise ValueError("Invalid heuristic provided.")
 
+    def split_ports_sequentially(
+        self,
+        read_ports: int,
+        write_ports: int,
+        total_ports: int,
+        sequence: List[Process],
+    ) -> List["ProcessCollection"]:
+        """
+        Split this collection into multiple new collections by sequentially assigning
+        processes in the order of `sequence`.
+
+        This method takes the processes from `sequence`, in order, and assignes them to
+        to multiple new `ProcessCollection` based on port collisions in a first-come
+        first-served manner. The first `Process` in `sequence` is assigned first, and
+        the last `Proccess` in `sequence is assigned last.
+
+        Parameters
+        ----------
+        read_ports : int
+            The number of read ports used when splitting process collection based on
+            memory variable access.
+        write_ports : int
+            The number of write ports used when splitting process collection based on
+            memory variable access.
+        total_ports : int
+            The total number of ports used when splitting process collection based on
+            memory variable access.
+        sequence: list of `Process`
+            A list of the processes used to determine the order in which processes are
+            assigned.
+
+        Returns
+        -------
+        A set of new ProcessCollection objects with the process splitting.
+        """
+
+        def ports_collide(proc: Process, collection: ProcessCollection):
+            """
+            Predicate test if insertion of a process `proc` results in colliding ports
+            when inserted to `collection` based on the `read_ports`, `write_ports`, and
+            `total_ports`.
+            """
+
+            # Test the number of concurrent write accesses
+            collection_writes = defaultdict(int, collection.write_port_accesses())
+            if collection_writes[proc.start_time] >= write_ports:
+                return True
+
+            # Test the number of concurrent read accesses
+            collection_reads = defaultdict(int, collection.read_port_accesses())
+            for proc_read_time in proc.read_times:
+                if collection_reads[proc_read_time % self.schedule_time] >= read_ports:
+                    return True
+
+            # Test the number of total accesses
+            collection_total_accesses = defaultdict(
+                int, Counter(collection_writes) + Counter(collection_reads)
+            )
+            for access_time in [proc.start_time, *proc.read_times]:
+                if collection_total_accesses[access_time] >= total_ports:
+                    return True
+
+            # No collision detected
+            return False
+
+        # Make sure that processes from `sequence` and and `self` are equal
+        if set(self.collection) != set(sequence):
+            raise KeyError("processes in `sequence` must be equal to processes in self")
+
+        collections: List[ProcessCollection] = []
+        for process in sequence:
+            process_added = False
+            for collection in collections:
+                if not ports_collide(process, collection):
+                    collection.add_process(process)
+                    process_added = True
+                    break
+            if not process_added:
+                # Stuff the process in a new collection
+                collections.append(
+                    ProcessCollection(
+                        [process],
+                        schedule_time=self.schedule_time,
+                        cyclic=self._cyclic,
+                    )
+                )
+        # Return the list of created ProcessCollections
+        return collections
+
     def _split_ports_graph_color(
         self,
         read_ports: int,
diff --git a/test/test_architecture.py b/test/test_architecture.py
index 108734b4..2dac82ff 100644
--- a/test/test_architecture.py
+++ b/test/test_architecture.py
@@ -157,10 +157,7 @@ def test_architecture(schedule_direct_form_iir_lp_filter: Schedule):
 
     # Graph representation
     # Parts are non-deterministic, but this first part seems OK
-    s = (
-        'digraph {\n\tnode [shape=box]\n\tsplines=spline\n\tsubgraph'
-        ' cluster_memories'
-    )
+    s = 'digraph {\n\tnode [shape=box]\n\tsplines=spline\n\tsubgraph cluster_memories'
     assert architecture._digraph().source.startswith(s)
     s = 'digraph {\n\tnode [shape=box]\n\tsplines=spline\n\tMEM0'
     assert architecture._digraph(cluster=False).source.startswith(s)
@@ -229,9 +226,9 @@ def test_move_process(schedule_direct_form_iir_lp_filter: Schedule):
     architecture.move_process('in0.0', memories[1], memories[0])
     assert memories[0].collection.from_name('in0.0')
 
-    assert processing_elements[1].collection.from_name('add0')
-    architecture.move_process('add0', processing_elements[1], processing_elements[0])
     assert processing_elements[0].collection.from_name('add0')
+    architecture.move_process('add0', processing_elements[0], processing_elements[1])
+    assert processing_elements[1].collection.from_name('add0')
 
     # Processes leave the resources they have moved from
     with pytest.raises(KeyError):
@@ -239,7 +236,7 @@ def test_move_process(schedule_direct_form_iir_lp_filter: Schedule):
     with pytest.raises(KeyError):
         memories[1].collection.from_name('in0.0')
     with pytest.raises(KeyError):
-        processing_elements[1].collection.from_name('add0')
+        processing_elements[0].collection.from_name('add0')
 
     # Processes can only be moved when the source and destination process-types match
     with pytest.raises(TypeError, match="cmul3.0 not of type"):
diff --git a/test/test_resources.py b/test/test_resources.py
index aac845b2..8925f6e6 100644
--- a/test/test_resources.py
+++ b/test/test_resources.py
@@ -1,8 +1,8 @@
 import re
 
 import matplotlib.pyplot as plt
-import pytest
 import matplotlib.testing.decorators
+import pytest
 
 from b_asic.core_operations import ConstantMultiplication
 from b_asic.process import PlainMemoryVariable
@@ -14,25 +14,57 @@ from b_asic.resources import ProcessCollection, _ForwardBackwardTable
 
 
 class TestProcessCollectionPlainMemoryVariable:
-    @matplotlib.testing.decorators.image_comparison(['test_draw_process_collection.png'])
+    @matplotlib.testing.decorators.image_comparison(
+        ['test_draw_process_collection.png']
+    )
     def test_draw_process_collection(self, simple_collection):
         fig, ax = plt.subplots()
         simple_collection.plot(ax=ax, show_markers=False)
         return fig
 
-    @matplotlib.testing.decorators.image_comparison(['test_draw_matrix_transposer_4.png'])
+    @matplotlib.testing.decorators.image_comparison(
+        ['test_draw_matrix_transposer_4.png']
+    )
     def test_draw_matrix_transposer_4(self):
         fig, ax = plt.subplots()
         generate_matrix_transposer(4).plot(ax=ax)  # type: ignore
         return fig
 
-    def test_split_memory_variable(self, simple_collection: ProcessCollection):
+    def test_split_memory_variable_graph_color(
+        self, simple_collection: ProcessCollection
+    ):
         collection_split = simple_collection.split_on_ports(
             heuristic="graph_color", read_ports=1, write_ports=1, total_ports=2
         )
         assert len(collection_split) == 3
 
-    @matplotlib.testing.decorators.image_comparison(['test_left_edge_cell_assignment.png'])
+    def test_split_sequence_raises(self, simple_collection: ProcessCollection):
+        with pytest.raises(KeyError, match="processes in `sequence` must be"):
+            simple_collection.split_ports_sequentially(
+                read_ports=1, write_ports=1, total_ports=2, sequence=[]
+            )
+
+    def test_split_memory_variable_left_edge(
+        self, simple_collection: ProcessCollection
+    ):
+        split = simple_collection.split_on_ports(
+            heuristic="left_edge", read_ports=1, write_ports=1, total_ports=2
+        )
+        assert len(split) == 3
+
+        split = simple_collection.split_on_ports(
+            heuristic="left_edge", read_ports=1, write_ports=2, total_ports=2
+        )
+        assert len(split) == 3
+
+        split = simple_collection.split_on_ports(
+            heuristic="left_edge", read_ports=2, write_ports=2, total_ports=2
+        )
+        assert len(split) == 2
+
+    @matplotlib.testing.decorators.image_comparison(
+        ['test_left_edge_cell_assignment.png']
+    )
     def test_left_edge_cell_assignment(self, simple_collection: ProcessCollection):
         fig, ax = plt.subplots(1, 2)
         assignment = list(simple_collection._left_edge_assignment())
@@ -158,7 +190,9 @@ class TestProcessCollectionPlainMemoryVariable:
         assert len(simple_collection) == 7
         assert new_proc not in simple_collection
 
-    @matplotlib.testing.decorators.image_comparison(['test_max_min_lifetime_bar_plot.png'])
+    @matplotlib.testing.decorators.image_comparison(
+        ['test_max_min_lifetime_bar_plot.png']
+    )
     def test_max_min_lifetime_bar_plot(self):
         fig, ax = plt.subplots()
         collection = ProcessCollection(
-- 
GitLab