diff --git a/b_asic/architecture.py b/b_asic/architecture.py
index 45f1362c92721c72cdc974db6ae62cb5df3a193f..c2c6dde38cf2b74760e5e2224193287fddf9e2c2 100644
--- a/b_asic/architecture.py
+++ b/b_asic/architecture.py
@@ -737,7 +737,11 @@ of :class:`~b_asic.architecture.ProcessingElement`
         self._build_dicts()
 
     def _digraph(
-        self, branch_node=True, cluster=True, splines: str = "spline"
+        self,
+        branch_node: bool = True,
+        cluster: bool = True,
+        splines: str = "spline",
+        io_cluster: bool = True,
     ) -> Digraph:
         dg = Digraph(node_attr={'shape': 'record'})
         dg.attr(splines=splines)
@@ -745,21 +749,31 @@ of :class:`~b_asic.architecture.ProcessingElement`
         if cluster:
             # Add subgraphs
             if len(self._memories):
-                with dg.subgraph(name='cluster_0') as c:
-                    for i, mem in enumerate(self._memories):
+                with dg.subgraph(name='cluster_memories') as c:
+                    for mem in self._memories:
                         c.node(mem.entity_name, mem._struct_def())
                     label = "Memory" if len(self._memories) <= 1 else "Memories"
                     c.attr(label=label)
-            with dg.subgraph(name='cluster_1') as c:
-                for i, pe in enumerate(self._processing_elements):
-                    c.node(pe.entity_name, pe._struct_def())
+            with dg.subgraph(name='cluster_pes') as c:
+                for pe in self._processing_elements:
+                    if pe._type_name not in ('in', 'out'):
+                        c.node(pe.entity_name, pe._struct_def())
                 label = (
                     "Processing element"
                     if len(self._processing_elements) <= 1
                     else "Processing elements"
                 )
                 c.attr(label=label)
-
+            if io_cluster:
+                with dg.subgraph(name='cluster_io') as c:
+                    for pe in self._processing_elements:
+                        if pe._type_name in ('in', 'out'):
+                            c.node(pe.entity_name, pe._struct_def())
+                    c.attr(label="IO")
+            else:
+                for pe in self._processing_elements:
+                    if pe._type_name in ('in', 'out'):
+                        dg.node(pe.entity_name, pe._struct_def())
         else:
             for i, mem in enumerate(self._memories):
                 dg.node(mem.entity_name, mem._struct_def())