diff --git a/b_asic/architecture.py b/b_asic/architecture.py
index 4ff70d382f56be8111892df3d72b360eab7b8bbf..45f1362c92721c72cdc974db6ae62cb5df3a193f 100644
--- a/b_asic/architecture.py
+++ b/b_asic/architecture.py
@@ -736,13 +736,35 @@ of :class:`~b_asic.architecture.ProcessingElement`
             raise KeyError(f"{proc} not in {re_from.entity_name}")
         self._build_dicts()
 
-    def _digraph(self, branch_node=True) -> Digraph:
+    def _digraph(
+        self, branch_node=True, cluster=True, splines: str = "spline"
+    ) -> Digraph:
         dg = Digraph(node_attr={'shape': 'record'})
+        dg.attr(splines=splines)
         # Add nodes for memories and PEs to graph
-        for i, mem in enumerate(self._memories):
-            dg.node(mem.entity_name, mem._struct_def())
-        for i, pe in enumerate(self._processing_elements):
-            dg.node(pe.entity_name, pe._struct_def())
+        if cluster:
+            # Add subgraphs
+            if len(self._memories):
+                with dg.subgraph(name='cluster_0') as c:
+                    for i, mem in enumerate(self._memories):
+                        c.node(mem.entity_name, mem._struct_def())
+                    label = "Memory" if len(self._memories) <= 1 else "Memories"
+                    c.attr(label=label)
+            with dg.subgraph(name='cluster_1') as c:
+                for i, pe in enumerate(self._processing_elements):
+                    c.node(pe.entity_name, pe._struct_def())
+                label = (
+                    "Processing element"
+                    if len(self._processing_elements) <= 1
+                    else "Processing elements"
+                )
+                c.attr(label=label)
+
+        else:
+            for i, mem in enumerate(self._memories):
+                dg.node(mem.entity_name, mem._struct_def())
+            for i, pe in enumerate(self._processing_elements):
+                dg.node(pe.entity_name, pe._struct_def())
 
         # Create list of interconnects
         edges: DefaultDict[str, Set[Tuple[str, str]]] = defaultdict(set)