diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 16bc99dfaed1d246ca7302e13b887ea2c3e6d216..3e6254d48bd239f1d749b9e88e66503308c19b59 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -10,9 +10,10 @@ from io import StringIO from queue import PriorityQueue import itertools as it from graphviz import Digraph +from graphviz.backend import FORMATS as GRAPHVIZ_FORMATS, ENGINES as GRAPHVIZ_ENGINES from b_asic.port import SignalSourceProvider, OutputPort -from b_asic.operation import Operation, AbstractOperation, ResultKey, DelayMap, MutableResultMap, MutableDelayMap +from b_asic.operation import Operation, AbstractOperation, ResultKey, MutableResultMap, MutableDelayMap from b_asic.signal import Signal from b_asic.graph_component import GraphID, GraphIDNumber, GraphComponent, Name, TypeName from b_asic.special_operations import Input, Output, Delay @@ -846,3 +847,71 @@ class SFG(AbstractOperation): src.index, input_values, results, delays, key_base, bits_override, truncate) results[key] = value return value + + def sfg(self, show_id=False, engine=None) -> Digraph: + """ + Returns a Digraph of the SFG. Can be directly displayed in IPython. + + Parameters + ---------- + show_id : Boolean, optional + If True, the graph_id:s of signals are shown. The default is False. + + engine: string, optional + Graphviz layout engine to be used, see https://graphviz.org/documentation/. + Most common are "dot" and "neato". Default is None leading to dot. + + Returns + ------- + Digraph + Digraph of the SFG. + + """ + dg = Digraph() + dg.attr(rankdir='LR') + if engine: + assert engine in GRAPHVIZ_ENGINES, "Unknown layout engine" + dg.engine = engine + for op in self._components_by_id.values(): + if isinstance(op, Signal): + if show_id: + dg.edge(op.source.operation.graph_id, op.destination.operation.graph_id, label=op.graph_id) + else: + dg.edge(op.source.operation.graph_id, op.destination.operation.graph_id) + else: + if op.type_name() == Delay.type_name(): + dg.node(op.graph_id, shape='square') + else: + dg.node(op.graph_id) + return dg + + def _repr_svg_(self): + return self.sfg()._repr_svg_() + + def show_sfg(self, format=None, show_id=False, engine=None) -> None: + """ + Shows a visual representation of the SFG using the default system viewer. + + Parameters + ---------- + format : string, optional + File format of the generated graph. Output formats can be found at https://www.graphviz.org/doc/info/output.html + Most common are "pdf", "eps", "png", and "svg". Default is None which leads to PDF. + + + show_id : Boolean, optional + If True, the graph_id:s of signals are shown. The default is False. + + engine: string, optional + Graphviz layout engine to be used, see https://graphviz.org/documentation/. + Most common are "dot" and "neato". Default is None leading to dot. + """ + + dg = self.sfg(show_id=show_id) + if format: + assert format in GRAPHVIZ_FORMATS, "Unknown file format" + dg.format = format + if engine: + assert engine in GRAPHVIZ_ENGINES, "Unknown layout engine" + dg.engine = engine + dg.view() diff --git a/test/test_sfg.py b/test/test_sfg.py index 2124daa9f9694d395292485ed1036efe63c9c088..0a5fe96ed0f24607efdcacefdb1a5275cf229d83 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1016,3 +1016,49 @@ class TestGetComponentsOfType: assert [op.name for op in sfg_two_inputs_two_outputs.find_by_type_name(Output.type_name())] \ == ["OUT1", "OUT2"] + + +class TestPrecedenceGraph: + def test_precedence_graph(self, sfg_simple_filter): + res = 'digraph {\n\trankdir=LR\n\tsubgraph cluster_0 ' \ + '{\n\t\tlabel=N1\n\t\t"in1.0" [label=in1]\n\t\t"t1.0" [label=t1]' \ + '\n\t}\n\tsubgraph cluster_1 {\n\t\tlabel=N2\n\t\t"cmul1.0" ' \ + '[label=cmul1]\n\t}\n\tsubgraph cluster_2 ' \ + '{\n\t\tlabel=N3\n\t\t"add1.0" [label=add1]\n\t}\n\t"in1.0" ' \ + '-> add1\n\tadd1 [label=add1 shape=square]\n\tin1 -> "in1.0"' \ + '\n\tin1 [label=in1 shape=square]\n\t"t1.0" -> cmul1\n\tcmul1 ' \ + '[label=cmul1 shape=square]\n\t"t1.0" -> out1\n\tout1 ' \ + '[label=out1 shape=square]\n\tt1Out -> "t1.0"\n\tt1Out ' \ + '[label=t1 shape=square]\n\t"cmul1.0" -> add1\n\tadd1 ' \ + '[label=add1 shape=square]\n\tcmul1 -> "cmul1.0"\n\tcmul1 ' \ + '[label=cmul1 shape=square]\n\t"add1.0" -> t1In\n\tt1In ' \ + '[label=t1 shape=square]\n\tadd1 -> "add1.0"\n\tadd1 ' \ + '[label=add1 shape=square]\n}' + + assert sfg_simple_filter.precedence_graph().source == res + + +class TestSFGGraph: + def test_sfg(self, sfg_simple_filter): + res = 'digraph {\n\trankdir=LR\n\tin1\n\tin1 -> ' \ + 'add1\n\tout1\n\tt1 -> out1\n\tadd1\n\tcmul1 -> ' \ + 'add1\n\tcmul1\n\tadd1 -> t1\n\tt1 [shape=square]\n\tt1 ' \ + '-> cmul1\n}' + + assert sfg_simple_filter.sfg().source == res + + def test_sfg_show_id(self, sfg_simple_filter): + res = 'digraph {\n\trankdir=LR\n\tin1\n\tin1 -> add1 ' \ + '[label=s1]\n\tout1\n\tt1 -> out1 [label=s2]\n\tadd1' \ + '\n\tcmul1 -> add1 [label=s3]\n\tcmul1\n\tadd1 -> t1 ' \ + '[label=s4]\n\tt1 [shape=square]\n\tt1 -> cmul1 [label=s5]\n}' + + assert sfg_simple_filter.sfg(show_id=True).source == res + + def test_show_sfg_invalid_format(self, sfg_simple_filter): + with pytest.raises(AssertionError): + sfg_simple_filter.show_sfg(format="ppddff") + + def test_show_sfg_invalid_engine(self, sfg_simple_filter): + with pytest.raises(AssertionError): + sfg_simple_filter.show_sfg(engine="ppddff")