From 1f09757ae0bb96132cef5b9083008d115bb9e2df Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Tue, 16 May 2023 11:41:51 +0200
Subject: [PATCH] Add cluster argument to Architecture Digraph

---
 b_asic/architecture.py | 32 +++++++++++++++++++++++++++-----
 1 file changed, 27 insertions(+), 5 deletions(-)

diff --git a/b_asic/architecture.py b/b_asic/architecture.py
index 4ff70d38..45f1362c 100644
--- a/b_asic/architecture.py
+++ b/b_asic/architecture.py
@@ -736,13 +736,35 @@ of :class:`~b_asic.architecture.ProcessingElement`
             raise KeyError(f"{proc} not in {re_from.entity_name}")
         self._build_dicts()
 
-    def _digraph(self, branch_node=True) -> Digraph:
+    def _digraph(
+        self, branch_node=True, cluster=True, splines: str = "spline"
+    ) -> Digraph:
         dg = Digraph(node_attr={'shape': 'record'})
+        dg.attr(splines=splines)
         # Add nodes for memories and PEs to graph
-        for i, mem in enumerate(self._memories):
-            dg.node(mem.entity_name, mem._struct_def())
-        for i, pe in enumerate(self._processing_elements):
-            dg.node(pe.entity_name, pe._struct_def())
+        if cluster:
+            # Add subgraphs
+            if len(self._memories):
+                with dg.subgraph(name='cluster_0') as c:
+                    for i, mem in enumerate(self._memories):
+                        c.node(mem.entity_name, mem._struct_def())
+                    label = "Memory" if len(self._memories) <= 1 else "Memories"
+                    c.attr(label=label)
+            with dg.subgraph(name='cluster_1') as c:
+                for i, pe in enumerate(self._processing_elements):
+                    c.node(pe.entity_name, pe._struct_def())
+                label = (
+                    "Processing element"
+                    if len(self._processing_elements) <= 1
+                    else "Processing elements"
+                )
+                c.attr(label=label)
+
+        else:
+            for i, mem in enumerate(self._memories):
+                dg.node(mem.entity_name, mem._struct_def())
+            for i, pe in enumerate(self._processing_elements):
+                dg.node(pe.entity_name, pe._struct_def())
 
         # Create list of interconnects
         edges: DefaultDict[str, Set[Tuple[str, str]]] = defaultdict(set)
-- 
GitLab