From 5e5e6e040999bf56821e6be4c920b31bc6848880 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ivar=20H=C3=A4rnqvist?= <ivaha717@student.liu.se> Date: Fri, 24 Apr 2020 18:19:40 +0200 Subject: [PATCH] Refactor constructor so that Input signals and Output signals are connected to ports before traversal is started, that way edge cases of empty SFG's are easily handled --- b_asic/operation.py | 10 ++++++++++ b_asic/signal_flow_graph.py | 27 +++++++++++++++++++++++++-- test/fixtures/signal_flow_graph.py | 27 +++++++++++++++++++++++++++ test/test_depends.py | 19 +++++++++++++++++++ 4 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 test/test_depends.py diff --git a/b_asic/operation.py b/b_asic/operation.py index a0d0f48a..bb66e26b 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -180,6 +180,11 @@ class Operation(GraphComponent, SignalSourceProvider): """ raise NotImplementedError + @abstractmethod + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: + """Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index.""" + raise NotImplementedError + class AbstractOperation(Operation, AbstractGraphComponent): """Generic abstract operation class which most implementations will derive from. @@ -340,6 +345,11 @@ class AbstractOperation(Operation, AbstractGraphComponent): pass return [self] + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: + if output_index < 0 or output_index >= self.output_count: + raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") + return [i for i in range(self.input_count)] # By default, assume each output depends on all inputs. + @property def neighbors(self) -> Iterable[GraphComponent]: return list(self.input_signals) + list(self.output_signals) diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index d311bfac..bcfc9eeb 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -3,7 +3,7 @@ B-ASIC Signal Flow Graph Module. TODO: More info. """ -from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, Set +from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, MutableSet from numbers import Number from collections import defaultdict, deque @@ -45,7 +45,7 @@ class SFG(AbstractOperation): _graph_id_generator: GraphIDGenerator _input_operations: List[Input] _output_operations: List[Output] - _original_components_to_new: Set[GraphComponent] + _original_components_to_new: MutableSet[GraphComponent] _original_input_signals_to_indices: Dict[Signal, int] _original_output_signals_to_indices: Dict[Signal, int] @@ -155,6 +155,9 @@ class SFG(AbstractOperation): raise ValueError(f"Output signal #{output_index} is missing source in SFG") if signal.source.operation not in self._original_components_to_new: self._add_operation_connected_tree_copy(signal.source.operation) + + # Find dependencies. + def __str__(self) -> str: """Get a string representation of this SFG.""" @@ -233,6 +236,11 @@ class SFG(AbstractOperation): def split(self) -> Iterable[Operation]: return self.operations + + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: + if output_index < 0 or output_index >= self.output_count: + raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") + return self._inputs_required_for_source(self._output_operations[output_index].input(0).signals[0].source, set()) def copy_component(self, *args, **kwargs) -> GraphComponent: return super().copy_component(*args, **kwargs, inputs = self._input_operations, outputs = self._output_operations, @@ -423,3 +431,18 @@ class SFG(AbstractOperation): value = src.operation.evaluate_output(src.index, input_values, results, registers, src_prefix) results[key] = value return value + + def _inputs_required_for_source(self, src: OutputPort, visited: MutableSet[Operation]) -> Sequence[bool]: + if src.operation in visited: + return [] + visited.add(src.operation) + + if isinstance(src.operation, Input): + for i, input_operation in enumerate(self._input_operations): + if input_operation is src.operation: + return [i] + + input_indices = [] + for i in src.operation.inputs_required_for_output(src.index): + input_indices.extend(self._inputs_required_for_source(src.operation.input(i).signals[0].source, visited)) + return input_indices diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index 7a8c4a73..5a0ef25b 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -29,6 +29,33 @@ def sfg_two_inputs_two_outputs(): out2 = Output(add2) return SFG(inputs = [in1, in2], outputs = [out1, out2]) +@pytest.fixture +def sfg_two_inputs_two_outputs_independent(): + """Valid SFG with two inputs and two outputs, where the first output only depends + on the first input and the second output only depends on the second input. + . . + in1-------------------->out1 + . . + . . + . c1--+ . + . | . + . v . + in2------+ add1---->out2 + . | ^ . + . | | . + . +------+ . + . . + out1 = in1 + out2 = in2 + 3 + """ + in1 = Input() + in2 = Input() + c1 = Constant(3) + add1 = in2 + c1 + out1 = Output(in1) + out2 = Output(add1) + return SFG(inputs = [in1, in2], outputs = [out1, out2]) + @pytest.fixture def sfg_nested(): """Valid SFG with two inputs and one output. diff --git a/test/test_depends.py b/test/test_depends.py new file mode 100644 index 00000000..e2691105 --- /dev/null +++ b/test/test_depends.py @@ -0,0 +1,19 @@ +from b_asic import Addition, Butterfly + +class TestDepends: + def test_depends_addition(self): + add1 = Addition() + assert set(add1.inputs_required_for_output(0)) == {0, 1} + + def test_depends_butterfly(self): + bfly1 = Butterfly() + assert set(bfly1.inputs_required_for_output(0)) == {0, 1} + assert set(bfly1.inputs_required_for_output(1)) == {0, 1} + + def test_depends_sfg(self, sfg_two_inputs_two_outputs): + assert set(sfg_two_inputs_two_outputs.inputs_required_for_output(0)) == {0, 1} + assert set(sfg_two_inputs_two_outputs.inputs_required_for_output(1)) == {0, 1} + + def test_depends_sfg_independent(self, sfg_two_inputs_two_outputs_independent): + assert set(sfg_two_inputs_two_outputs_independent.inputs_required_for_output(0)) == {0} + assert set(sfg_two_inputs_two_outputs_independent.inputs_required_for_output(1)) == {1} \ No newline at end of file -- GitLab