From f69a4f3c627a5aa3e4d00d56c2e1989a4c61c177 Mon Sep 17 00:00:00 2001 From: Adam Jakobsson <adaja901@student.liu.se> Date: Tue, 14 Apr 2020 08:53:16 +0200 Subject: [PATCH] Refactor constructor so that Input signals and Output signals are connected to ports before traversal is started, that way edge cases of empty SFG's are easily handled --- b_asic/operation.py | 17 +- b_asic/port.py | 19 +- b_asic/signal_flow_graph.py | 382 +++++++++++++++++++++++++++--------- test/test_print_sfg.py | 46 +++++ test/test_sfg.py | 104 +++++++++- 5 files changed, 449 insertions(+), 119 deletions(-) create mode 100644 test/test_print_sfg.py diff --git a/b_asic/operation.py b/b_asic/operation.py index d644dbd3..ed327127 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -166,13 +166,14 @@ class AbstractOperation(Operation, AbstractGraphComponent): if input_sources is not None: source_count = len(input_sources) if source_count != input_count: - raise ValueError(f"Operation expected {input_count} input sources but only got {source_count}") + raise ValueError( + f"Operation expected {input_count} input sources but only got {source_count}") for i, src in enumerate(input_sources): if src is not None: self._input_ports[i].connect(src.source) @abstractmethod - def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ + def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ """Evaluate the operation and generate a list of output values given a list of input values. """ @@ -246,11 +247,13 @@ class AbstractOperation(Operation, AbstractGraphComponent): result = self.evaluate(*input_values) if isinstance(result, collections.Sequence): if len(result) != self.output_count: - raise RuntimeError("Operation evaluated to incorrect number of outputs") + raise RuntimeError( + "Operation evaluated to incorrect number of outputs") return result if isinstance(result, Number): if self.output_count != 1: - raise RuntimeError("Operation evaluated to incorrect number of outputs") + raise RuntimeError( + "Operation evaluated to incorrect number of outputs") return [result] raise RuntimeError("Operation evaluated to invalid type") @@ -296,11 +299,13 @@ class AbstractOperation(Operation, AbstractGraphComponent): def source(self) -> OutputPort: if self.output_count != 1: diff = "more" if self.output_count > 1 else "less" - raise TypeError(f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output") + raise TypeError( + f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output") return self.output(0) def copy_unconnected(self) -> GraphComponent: new_comp: AbstractOperation = super().copy_unconnected() for name, value in self.params.items(): - new_comp.set_param(name, deepcopy(value)) # pylint: disable=no-member + new_comp.set_param(name, deepcopy( + value)) # pylint: disable=no-member return new_comp diff --git a/b_asic/port.py b/b_asic/port.py index 4f249e3c..103d076a 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -8,6 +8,7 @@ from copy import copy from typing import NewType, Optional, List, Iterable, TYPE_CHECKING from b_asic.signal import Signal +from b_asic.graph_component import Name if TYPE_CHECKING: from b_asic.operation import Operation @@ -144,22 +145,24 @@ class InputPort(AbstractPort): """ return None if self._source_signal is None else self._source_signal.source - def connect(self, src: SignalSourceProvider) -> Signal: + def connect(self, src: SignalSourceProvider, name: Name = "") -> Signal: """Connect the provided signal source to this input port by creating a new signal. Returns the new signal. """ assert self._source_signal is None, "Attempted to connect already connected input port." - return Signal(src.source, self) # self._source_signal is set by the signal constructor. - + # self._source_signal is set by the signal constructor. + return Signal(source=src.source, destination=self, name=name) + @property def value_length(self) -> Optional[int]: """Get the number of bits that this port should truncate received values to.""" return self._value_length - + @value_length.setter def value_length(self, bits: Optional[int]) -> None: """Set the number of bits that this port should truncate received values to.""" - assert bits is None or (isinstance(bits, int) and bits >= 0), "Value length must be non-negative." + assert bits is None or (isinstance( + bits, int) and bits >= 0), "Value length must be non-negative." self._value_length = bits @@ -185,7 +188,7 @@ class OutputPort(AbstractPort, SignalSourceProvider): def add_signal(self, signal: Signal) -> None: assert signal not in self._destination_signals, "Attempted to add already connected signal." self._destination_signals.append(signal) - signal.set_source(self) + signal.set_source(self) def remove_signal(self, signal: Signal) -> None: assert signal in self._destination_signals, "Attempted to remove already removed signal." @@ -195,7 +198,7 @@ class OutputPort(AbstractPort, SignalSourceProvider): def clear(self) -> None: for signal in copy(self._destination_signals): self.remove_signal(signal) - + @property def source(self) -> "OutputPort": - return self \ No newline at end of file + return self diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index a011653f..2f2a0240 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -5,7 +5,7 @@ TODO: More info. from typing import NewType, List, Iterable, Sequence, Dict, Optional, DefaultDict, Set from numbers import Number -from collections import defaultdict +from collections import defaultdict, deque from b_asic.port import SignalSourceProvider, OutputPort from b_asic.operation import Operation, AbstractOperation @@ -36,7 +36,7 @@ class SFG(AbstractOperation): """Signal flow graph. TODO: More info. """ - + _components_by_id: Dict[GraphID, GraphComponent] _components_by_name: DefaultDict[Name, List[GraphComponent]] _graph_id_generator: GraphIDGenerator @@ -46,61 +46,115 @@ class SFG(AbstractOperation): _original_input_signals: Dict[Signal, int] _original_output_signals: Dict[Signal, int] - def __init__(self, input_signals: Sequence[Signal] = [], output_signals: Sequence[Signal] = [], \ - inputs: Sequence[Input] = [], outputs: Sequence[Output] = [], operations: Sequence[Operation] = [], \ - id_number_offset: GraphIDNumber = 0, name: Name = "", \ + def __init__(self, input_signals: Sequence[Signal] = [], output_signals: Sequence[Signal] = [], + inputs: Sequence[Input] = [], outputs: Sequence[Output] = [], + id_number_offset: GraphIDNumber = 0, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None): super().__init__( - input_count = len(input_signals) + len(inputs), - output_count = len(output_signals) + len(outputs), - name = name, - input_sources = input_sources) + input_count=len(input_signals) + len(inputs), + output_count=len(output_signals) + len(outputs), + name=name, + input_sources=input_sources) self._components_by_id = dict() self._components_by_name = defaultdict(list) + self._components_in_dfs_order = [] self._graph_id_generator = GraphIDGenerator(id_number_offset) self._input_operations = [] self._output_operations = [] - self._original_components_added = set() - self._original_input_signals = {} - self._original_output_signals = {} - - # Setup input operations and signals. - for i, s in enumerate(input_signals): - self._input_operations.append(self._add_component_copy_unconnected(Input())) - self._original_input_signals[s] = i - for i, op in enumerate(inputs, len(input_signals)): - self._input_operations.append(self._add_component_copy_unconnected(op)) - for s in op.output(0).signals: - self._original_input_signals[s] = i - - # Setup output operations and signals. - for i, s in enumerate(output_signals): - self._output_operations.append(self._add_component_copy_unconnected(Output())) - self._original_output_signals[s] = i - for i, op in enumerate(outputs, len(output_signals)): - self._output_operations.append(self._add_component_copy_unconnected(op)) - for s in op.input(0).signals: - self._original_output_signals[s] = i - + # Maps original components to new copied components + self._added_components_mapping = {} + self._original_input_signals_indexes = {} + self._original_output_signals_indexes = {} + self._id_number_offset = id_number_offset + + # Setup input signals. + for input_index, sig in enumerate(input_signals): + assert sig not in self._added_components_mapping, "Duplicate input signals sent to SFG construcctor." + + new_input_op = self._add_component_copy_unconnected(Input()) + new_sig = self._add_component_copy_unconnected(sig) + new_sig.set_source(new_input_op.output(0)) + + self._input_operations.append(new_input_op) + self._original_input_signals_indexes[sig] = input_index + + # Setup input operations, starting from indexes ater input signals. + for input_index, input_op in enumerate(inputs, len(input_signals)): + assert input_op not in self._added_components_mapping, "Duplicate input operations sent to SFG constructor." + new_input_op = self._add_component_copy_unconnected(input_op) + + for sig in input_op.output(0).signals: + assert sig not in self._added_components_mapping, "Duplicate input signals connected to input ports sent to SFG construcctor." + new_sig = self._add_component_copy_unconnected(sig) + new_sig.set_source(new_input_op.output(0)) + + self._original_input_signals_indexes[sig] = input_index + + self._input_operations.append(new_input_op) + + # Setup output signals. + for output_ind, sig in enumerate(output_signals): + new_out = self._add_component_copy_unconnected(Output()) + if sig in self._added_components_mapping: + # Signal already added when setting up inputs + new_sig = self._added_components_mapping[sig] + new_sig.set_destination(new_out.input(0)) + else: + # New signal has to be created + new_sig = self._add_component_copy_unconnected(sig) + new_sig.set_destination(new_out.input(0)) + + self._output_operations.append(new_out) + self._original_output_signals_indexes[sig] = output_ind + + # Setup output operations, starting from indexes after output signals. + for output_ind, output_op in enumerate(outputs, len(output_signals)): + assert output_op not in self._added_components_mapping, "Duplicate output operations sent to SFG constructor." + + new_out = self._add_component_copy_unconnected(output_op) + for sig in output_op.input(0).signals: + if sig in self._added_components_mapping: + # Signal already added when setting up inputs + new_sig = self._added_components_mapping[sig] + new_sig.set_destination(new_out.input(0)) + else: + # New signal has to be created + new_sig = self._add_component_copy_unconnected(sig) + new_sig.set_destination(new_out.input(0)) + + self._original_output_signals_indexes[sig] = output_ind + + self._output_operations.append(new_out) + + output_operations_set = set(self._output_operations) + # Search the graph inwards from each input signal. - for s, i in self._original_input_signals.items(): - if s.destination is None: - raise ValueError(f"Input signal #{i} is missing destination in SFG") - if s.destination.operation not in self._original_components_added: - self._add_operation_copy_recursively(s.destination.operation) + for sig, input_index in self._original_input_signals_indexes.items(): + # Check if already added destination. + new_sig = self._added_components_mapping[sig] + if new_sig.destination is not None and new_sig.destination.operation in output_operations_set: + # Add directly connected input to output to dfs order list + self._components_in_dfs_order.extend([ + new_sig.source.operation, new_sig, new_sig.destination.operation]) + elif sig.destination is None: + raise ValueError( + f"Input signal #{input_index} is missing destination in SFG") + elif sig.destination.operation not in self._added_components_mapping: + self._copy_structure_from_operation_dfs( + sig.destination.operation) # Search the graph inwards from each output signal. - for s, i in self._original_output_signals.items(): - if s.source is None: - raise ValueError(f"Output signal #{i} is missing source in SFG") - if s.source.operation not in self._original_components_added: - self._add_operation_copy_recursively(s.source.operation) - - # Search the graph outwards from each operation. - for op in operations: - if op not in self._original_components_added: - self._add_operation_copy_recursively(op) + for sig, output_index in self._original_output_signals_indexes.items(): + # Check if already added source. + mew_sig = self._added_components_mapping[sig] + if new_sig.source is None: + if sig.source is None: + raise ValueError( + f"Output signal #{output_index} is missing source in SFG") + if sig.source.operation not in self._added_components_mapping: + self._copy_structure_from_operation_dfs( + sig.source.operation) @property def type_name(self) -> TypeName: @@ -108,10 +162,11 @@ class SFG(AbstractOperation): def evaluate(self, *args): if len(args) != self.input_count: - raise ValueError("Wrong number of inputs supplied to SFG for evaluation") + raise ValueError( + "Wrong number of inputs supplied to SFG for evaluation") for arg, op in zip(args, self._input_operations): op.value = arg - + result = [] for op in self._output_operations: result.append(self._evaluate_source(op.input(0).signals[0].source)) @@ -122,7 +177,8 @@ class SFG(AbstractOperation): def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]: assert i >= 0 and i < self.output_count, "Output index out of range" result = [None] * self.output_count - result[i] = self._evaluate_source(self._output_operations[i].input(0).signals[0].source) + result[i] = self._evaluate_source( + self._output_operations[i].input(0).signals[0].source) return result def split(self) -> Iterable[Operation]: @@ -130,8 +186,8 @@ class SFG(AbstractOperation): @property def components(self) -> Iterable[GraphComponent]: - """Get all components of this graph.""" - return self._components_by_id.values() + """Get all components of this graph in the dfs-traversal order.""" + return self._components_in_dfs_order def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]: """Find a graph object based on the entered Graph ID and return it. If no graph @@ -152,58 +208,194 @@ class SFG(AbstractOperation): """ return self._components_by_name.get(name, []) + def deep_copy(self) -> "SFG": + """Returns a deep copy of self.""" + copy = SFG(inputs=self._input_operations, outputs=self._output_operations, + id_number_offset=self._id_number_offset, name=super().name) + + return copy + def _add_component_copy_unconnected(self, original_comp: GraphComponent) -> GraphComponent: - assert original_comp not in self._original_components_added, "Tried to add duplicate SFG component" - self._original_components_added.add(original_comp) + + assert original_comp not in self._added_components_mapping, "Tried to add duplicate SFG component" new_comp = original_comp.copy_unconnected() - self._components_by_id[self._graph_id_generator.next_id(new_comp.type_name)] = new_comp + + self._added_components_mapping[original_comp] = new_comp + self._components_by_id[self._graph_id_generator.next_id( + new_comp.type_name)] = new_comp self._components_by_name[new_comp.name].append(new_comp) + return new_comp - def _add_operation_copy_recursively(self, original_op: Operation) -> Operation: - # Add a copy of the operation without any connections. - new_op = self._add_component_copy_unconnected(original_op) - - # Connect input ports. - for original_input_port, new_input_port in zip(original_op.inputs, new_op.inputs): - if original_input_port.signal_count < 1: - raise ValueError("Unconnected input port in SFG") - for original_signal in original_input_port.signals: - if original_signal in self._original_input_signals: # Check if the signal is one of the SFG's input signals. - new_signal = self._add_component_copy_unconnected(original_signal) - new_signal.set_destination(new_input_port) - new_signal.set_source(self._input_operations[self._original_input_signals[original_signal]].output(0)) - elif original_signal not in self._original_components_added: # Only add the signal if it wasn't already added. - new_signal = self._add_component_copy_unconnected(original_signal) - new_signal.set_destination(new_input_port) - if original_signal.source is None: - raise ValueError("Dangling signal without source in SFG") - # Recursively add the connected operation. - new_connected_op = self._add_operation_copy_recursively(original_signal.source.operation) - new_signal.set_source(new_connected_op.output(original_signal.source.index)) - - # Connect output ports. - for original_output_port, new_output_port in zip(original_op.outputs, new_op.outputs): - for original_signal in original_output_port.signals: - if original_signal in self._original_output_signals: # Check if the signal is one of the SFG's output signals. - new_signal = self._add_component_copy_unconnected(original_signal) - new_signal.set_source(new_output_port) - new_signal.set_destination(self._output_operations[self._original_output_signals[original_signal]].input(0)) - elif original_signal not in self._original_components_added: # Only add the signal if it wasn't already added. - new_signal = self._add_component_copy_unconnected(original_signal) - new_signal.set_source(new_output_port) - if original_signal.destination is None: - raise ValueError("Dangling signal without destination in SFG") - # Recursively add the connected operation. - new_connected_op = self._add_operation_copy_recursively(original_signal.destination.operation) - new_signal.set_destination(new_connected_op.input(original_signal.destination.index)) - - return new_op - + def _copy_structure_from_operation_dfs(self, start_op: Operation): + op_stack = deque([start_op]) + + while op_stack: + original_op = op_stack.pop() + # Add or get the new copy of the operation.. + new_op = None + if original_op not in self._added_components_mapping: + new_op = self._add_component_copy_unconnected(original_op) + self._components_in_dfs_order.append(new_op) + else: + new_op = self._added_components_mapping[original_op] + + # Connect input ports to new signals + for original_input_port in original_op.inputs: + if original_input_port.signal_count < 1: + raise ValueError("Unconnected input port in SFG") + + for original_signal in original_input_port.signals: + + # Check if the signal is one of the SFG's input signals + if original_signal in self._original_input_signals_indexes: + + # New signal already created during first step of constructor + new_signal = self._added_components_mapping[ + original_signal] + new_signal.set_destination( + new_op.input(original_input_port.index)) + + self._components_in_dfs_order.extend( + [new_signal, new_signal.source.operation]) + + # Check if the signal has not been added before + elif original_signal not in self._added_components_mapping: + if original_signal.source is None: + raise ValueError( + "Dangling signal without source in SFG") + + new_signal = self._add_component_copy_unconnected( + original_signal) + new_signal.set_destination( + new_op.input(original_input_port.index)) + + self._components_in_dfs_order.append(new_signal) + + original_connected_op = original_signal.source.operation + # Check if connected Operation has been added before + if original_connected_op in self._added_components_mapping: + # Set source to the already added operations port + new_signal.set_source( + self._added_components_mapping[original_connected_op].output( + original_signal.source.index)) + else: + # Create new operation, set signal source to it + new_connected_op = self._add_component_copy_unconnected( + original_connected_op) + new_signal.set_source(new_connected_op.output( + original_signal.source.index)) + + self._components_in_dfs_order.append( + new_connected_op) + + # Add connected operation to queue of operations to visit + op_stack.append(original_connected_op) + + # Connect output ports + for original_output_port in original_op.outputs: + + for original_signal in original_output_port.signals: + # Check if the signal is one of the SFG's output signals. + if original_signal in self._original_output_signals_indexes: + + # New signal already created during first step of constructor. + new_signal = self._added_components_mapping[ + original_signal] + new_signal.set_source( + new_op.output(original_output_port.index)) + + self._components_in_dfs_order.extend( + [new_signal, new_signal.destination.operation]) + + # Check if signal has not been added before. + elif original_signal not in self._added_components_mapping: + if original_signal.source is None: + raise ValueError( + "Dangling signal without source in SFG") + + new_signal = self._add_component_copy_unconnected( + original_signal) + new_signal.set_source( + new_op.output(original_output_port.index)) + + self._components_in_dfs_order.append(new_signal) + + original_connected_op = original_signal.destination.operation + # Check if connected operation has been added. + if original_connected_op in self._added_components_mapping: + # Set destination to the already connected operations port + new_signal.set_destination( + self._added_components_mapping[original_connected_op].input( + original_signal.destination.index)) + + else: + # Create new operation, set destination to it. + new_connected_op = self._add_component_copy_unconnected( + original_connected_op) + new_signal.set_destination(new_connected_op.input( + original_signal.destination.index)) + + self._components_in_dfs_order.append( + new_connected_op) + + # Add connected operation to the queue of operations to visist + op_stack.append(original_connected_op) + def _evaluate_source(self, src: OutputPort) -> Number: input_values = [] for input_port in src.operation.inputs: input_src = input_port.signals[0].source input_values.append(self._evaluate_source(input_src)) - return src.operation.evaluate_output(src.index, input_values) \ No newline at end of file + return src.operation.evaluate_output(src.index, input_values) + + + def __str__(self): + """Prints operations, inputs and outputs in a SFG + """ + + output_string = "" + + for comp in self._components_in_dfs_order: + if isinstance(comp, Operation): + for key, value in self._components_by_id.items(): + if value is comp: + output_string += "id: " + key + ", name: " + + if comp.name != None: + output_string += comp.name + ", " + else: + output_string += "-, " + + if comp.type_name is "c": + output_string += "value: " + str(comp.value) + ", input: [" + else: + output_string += "input: [" + + counter_input = 0 + for input in comp.inputs: + counter_input += 1 + for signal in input.signals: + for key, value in self._components_by_id.items(): + if value is signal: + output_string += key + ", " + + if counter_input > 0: + output_string = output_string[:-2] + output_string += "], output: [" + counter_output = 0 + for output in comp.outputs: + counter_output += 1 + for signal in output.signals: + for key, value in self._components_by_id.items(): + if value is signal: + output_string += key + ", " + if counter_output > 0: + output_string = output_string[:-2] + output_string += "]\n" + + return output_string + + + diff --git a/test/test_print_sfg.py b/test/test_print_sfg.py new file mode 100644 index 00000000..feb3626e --- /dev/null +++ b/test/test_print_sfg.py @@ -0,0 +1,46 @@ +""" +B-ASIC test suite for printing a SFG +""" + + +from b_asic.signal_flow_graph import SFG +from b_asic.core_operations import Addition, Multiplication, Constant, ConstantAddition +from b_asic.port import InputPort, OutputPort +from b_asic.signal import Signal +from b_asic.special_operations import Input, Output + +import pytest + + +class TestPrintSfg: + def test_print_one_addition(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + add1 = Addition(inp1, inp2, "ADD1") + out1 = Output(add1, "OUT1") + sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="sf1") + + assert sfg.__str__() == ("id: add1, name: ADD1, input: [s1, s2], output: [s3]\nid: in1, name: INP1, input: [], output: [s1]\nid: in2, name: INP2, input: [], output: [s2]\nid: out1, name: OUT1, input: [s3], output: []\n") + + def test_print_add_mul(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(inp1, inp2, "ADD1") + mul1 = Multiplication(add1, inp3, "MUL1") + out1 = Output(mul1, "OUT1") + sfg = SFG(inputs=[inp1, inp2, inp3], outputs=[out1], name="mac_sfg") + + assert sfg.__str__() == ("id: add1, name: ADD1, input: [s1, s2], output: [s5]\nid: in1, name: INP1, input: [], output: [s1]\nid: in2, name: INP2, input: [], output: [s2]\nid: mul1, name: MUL1, input: [s5, s3], output: [s4]\nid: in3, name: INP3, input: [], output: [s3]\nid: out1, name: OUT1, input: [s4], output: []\n") + + def test_print_constant(self): + inp1 = Input("INP1") + const1 = Constant(3, "CONST") + add1 = Addition(const1, inp1, "ADD1") + out1 = Output(add1, "OUT1") + + sfg = SFG(inputs=[inp1], outputs=[out1], name="sfg") + + assert sfg.__str__() == ("id: add1, name: ADD1, input: [s3, s1], output: [s2]\nid: c1, name: CONST, value: 3, input: [], output: [s3]\nid: in1, name: INP1, input: [], output: [s1]\nid: out1, name: OUT1, input: [s2], output: []\n") + + \ No newline at end of file diff --git a/test/test_sfg.py b/test/test_sfg.py index d3daf2e9..af9dfe17 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1,9 +1,33 @@ from b_asic import SFG from b_asic.signal import Signal -from b_asic.core_operations import Addition, Constant +from b_asic.core_operations import Addition, Constant, Multiplication from b_asic.special_operations import Input, Output + class TestConstructor: + def test_direct_input_to_output_sfg_construction(self): + inp = Input("INP1") + out = Output(None, "OUT1") + out.input(0).connect(inp, "S1") + + sfg = SFG(inputs=[inp], outputs=[out]) + + assert len(list(sfg.components)) == 3 + assert sfg.input_count == 1 + assert sfg.output_count == 1 + + def test_same_signal_input_and_output_sfg_construction(self): + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + + sig1 = add2.input(0).connect(add1, "S1") + + sfg = SFG(input_signals=[sig1], output_signals=[sig1]) + + assert len(list(sfg.components)) == 3 + assert sfg.input_count == 1 + assert sfg.output_count == 1 + def test_outputs_construction(self, operation_tree): outp = Output(operation_tree) sfg = SFG(outputs=[outp]) @@ -20,13 +44,73 @@ class TestConstructor: assert sfg.input_count == 0 assert sfg.output_count == 1 - def test_operations_construction(self, operation_tree): - sfg1 = SFG(operations=[operation_tree]) - sfg2 = SFG(operations=[operation_tree.input(1).signals[0].source.operation]) - assert len(list(sfg1.components)) == 5 - assert len(list(sfg2.components)) == 5 - assert sfg1.input_count == 0 - assert sfg2.input_count == 0 - assert sfg1.output_count == 0 - assert sfg2.output_count == 0 +class TestDeepCopy: + def test_deep_copy_no_duplicates(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(inp1, inp2, "ADD1") + mul1 = Multiplication(add1, inp3, "MUL1") + out1 = Output(mul1, "OUT1") + + mac_sfg = SFG(inputs=[inp1, inp2], + outputs=[out1], name="mac_sfg") + + mac_sfg_deep_copy = mac_sfg.deep_copy() + + for g_id, component in mac_sfg._components_by_id.items(): + component_copy = mac_sfg_deep_copy.find_by_id(g_id) + assert component.name == component_copy.name + + def test_deep_copy(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + mul1 = Multiplication(None, None, "MUL1") + out1 = Output(None, "OUT1") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + add2.input(0).connect(add1, "S4") + add2.input(1).connect(inp3, "S3") + mul1.input(0).connect(add1, "S5") + mul1.input(1).connect(add2, "S6") + out1.input(0).connect(mul1, "S7") + + mac_sfg = SFG(inputs=[inp1, inp2], + outputs=[out1], name="mac_sfg") + + mac_sfg_deep_copy = mac_sfg.deep_copy() + + for g_id, component in mac_sfg._components_by_id.items(): + component_copy = mac_sfg_deep_copy.find_by_id(g_id) + assert component.name == component_copy.name + + +class TestComponents: + + def test_advanced_components(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + mul1 = Multiplication(None, None, "MUL1") + out1 = Output(None, "OUT1") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + add2.input(0).connect(add1, "S4") + add2.input(1).connect(inp3, "S3") + mul1.input(0).connect(add1, "S5") + mul1.input(1).connect(add2, "S6") + out1.input(0).connect(mul1, "S7") + + mac_sfg = SFG(inputs=[inp1, inp2], + outputs=[out1], name="mac_sfg") + + assert set([comp.name for comp in mac_sfg.components]) == { + "INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"} -- GitLab