diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 82aeb9c82ca3b3e691d7bb9e37328d7980334cb9..cc197d8e3200f0bdeb3b95421d1fcee6f3d0355d 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -1376,17 +1376,32 @@ class SFG(AbstractOperation): results[key] = value return value - def sfg_digraph(self, show_id=False, engine=None) -> Digraph: + def sfg_digraph( + self, + show_id: bool = False, + engine: str = None, + branch_node: bool = False, + port_numbering: bool = True, + splines: str = "spline", + ) -> Digraph: """ - Returns a Digraph of the SFG. Can be directly displayed in IPython. + Returns a Digraph of the SFG. + + Can be directly displayed in IPython. Parameters ---------- - show_id : Boolean, optional - If True, the graph_id:s of signals are shown. The default is False. + show_id : bool, default: False + If True, the graph_id:s of signals are shown. engine : string, optional Graphviz layout engine to be used, see https://graphviz.org/documentation/. Most common are "dot" and "neato". Default is None leading to dot. + branch_node : bool, default: False + Add a branch node in case the fan-out of a signal is two or more. + port_numbering : bool, default: True + Show the port number in case the number of ports (input or output) is two or more. + splines : {"spline", "line", "ortho", "polyline", "curved"}, default: "spline" + Spline style, see https://graphviz.org/docs/attrs/splines/ for more info. Returns ------- @@ -1395,23 +1410,56 @@ class SFG(AbstractOperation): """ dg = Digraph() - dg.attr(rankdir="LR") + dg.attr(rankdir="LR", splines=splines) + branch_nodes = set() if engine is not None: dg.engine = engine for op in self._components_by_id.values(): if isinstance(op, Signal): source = cast(OutputPort, op.source) destination = cast(InputPort, op.destination) - if show_id: - dg.edge( - source.operation.graph_id, - destination.operation.graph_id, - label=op.graph_id, + source_name = ( + source.name + if branch_node and source.signal_count > 1 + else source.operation.graph_id + ) + label = op.graph_id if show_id else None + taillabel = ( + str(source.index) + if source.operation.output_count > 1 + and (not branch_node or source.signal_count == 1) + and port_numbering + else None + ) + headlabel = ( + str(destination.index) + if destination.operation.input_count > 1 and port_numbering + else None + ) + dg.edge( + source_name, + destination.operation.graph_id, + label=label, + taillabel=taillabel, + headlabel=headlabel, + ) + if ( + branch_node + and source.signal_count > 1 + and source_name not in branch_nodes + ): + branch_nodes.add(source_name) + dg.node(source_name, shape='point') + taillabel = ( + str(source.index) + if source.operation.output_count > 1 and port_numbering + else None ) - else: dg.edge( source.operation.graph_id, - destination.operation.graph_id, + source_name, + arrowhead='none', + taillabel=taillabel, ) else: dg.node(op.graph_id, shape=_OPERATION_SHAPE[op.type_name()]) diff --git a/test/test_sfg.py b/test/test_sfg.py index 2c740f8deeb42481511c011d2af3d0a2155a31da..3d402347e7174eedd67f87cfce299dbb44ab60d5 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1233,19 +1233,19 @@ class TestPrecedenceGraph: class TestSFGGraph: def test_sfg(self, sfg_simple_filter): res = ( - 'digraph {\n\trankdir=LR\n\tin1 [shape=cds]\n\tin1 -> add1\n\tout1' - ' [shape=cds]\n\tt1 -> out1\n\tadd1 [shape=ellipse]\n\tcmul1 ->' - ' add1\n\tcmul1 [shape=ellipse]\n\tadd1 -> t1\n\tt1' - ' [shape=square]\n\tt1 -> cmul1\n}' + 'digraph {\n\trankdir=LR splines=spline\n\tin1 [shape=cds]\n\tin1 -> add1' + ' [headlabel=0]\n\tout1 [shape=cds]\n\tt1 -> out1\n\tadd1' + ' [shape=ellipse]\n\tcmul1 -> add1 [headlabel=1]\n\tcmul1' + ' [shape=ellipse]\n\tadd1 -> t1\n\tt1 [shape=square]\n\tt1 -> cmul1\n}' ) assert sfg_simple_filter.sfg_digraph().source in (res, res + "\n") def test_sfg_show_id(self, sfg_simple_filter): res = ( - 'digraph {\n\trankdir=LR\n\tin1 [shape=cds]\n\tin1 -> add1' - ' [label=s1]\n\tout1 [shape=cds]\n\tt1 -> out1 [label=s2]\n\tadd1' - ' [shape=ellipse]\n\tcmul1 -> add1 [label=s3]\n\tcmul1' - ' [shape=ellipse]\n\tadd1 -> t1 [label=s4]\n\tt1' + 'digraph {\n\trankdir=LR splines=spline\n\tin1 [shape=cds]\n\tin1 -> add1' + ' [label=s1 headlabel=0]\n\tout1 [shape=cds]\n\tt1 -> out1' + ' [label=s2]\n\tadd1 [shape=ellipse]\n\tcmul1 -> add1 [label=s3' + ' headlabel=1]\n\tcmul1 [shape=ellipse]\n\tadd1 -> t1 [label=s4]\n\tt1' ' [shape=square]\n\tt1 -> cmul1 [label=s5]\n}' ) @@ -1254,6 +1254,34 @@ class TestSFGGraph: res + "\n", ) + def test_sfg_branch(self, sfg_simple_filter): + res = ( + 'digraph {\n\trankdir=LR splines=spline\n\tin1 [shape=cds]\n\tin1 -> add1' + ' [headlabel=0]\n\tout1 [shape=cds]\n\t"t1.0" -> out1\n\t"t1.0"' + ' [shape=point]\n\tt1 -> "t1.0" [arrowhead=none]\n\tadd1' + ' [shape=ellipse]\n\tcmul1 -> add1 [headlabel=1]\n\tcmul1' + ' [shape=ellipse]\n\tadd1 -> t1\n\tt1 [shape=square]\n\t"t1.0" ->' + ' cmul1\n}' + ) + + assert sfg_simple_filter.sfg_digraph(branch_node=True).source in ( + res, + res + "\n", + ) + + def test_sfg_no_port_numbering(self, sfg_simple_filter): + res = ( + 'digraph {\n\trankdir=LR splines=spline\n\tin1 [shape=cds]\n\tin1 ->' + ' add1\n\tout1 [shape=cds]\n\tt1 -> out1\n\tadd1 [shape=ellipse]\n\tcmul1' + ' -> add1\n\tcmul1 [shape=ellipse]\n\tadd1 -> t1\n\tt1 [shape=square]\n\tt1' + ' -> cmul1\n}' + ) + + assert sfg_simple_filter.sfg_digraph(port_numbering=False).source in ( + res, + res + "\n", + ) + def test_show_sfg_invalid_format(self, sfg_simple_filter): with pytest.raises(ValueError): sfg_simple_filter.show(fmt="ppddff")