diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py index 5efa8038e6c914d0e8f2023ebf6f05eb58664e5b..e08422a842a84d08dcab58ab03d7f581cb1bc664 100644 --- a/b_asic/graph_component.py +++ b/b_asic/graph_component.py @@ -105,6 +105,10 @@ class AbstractGraphComponent(GraphComponent): self._graph_id = "" self._parameters = {} + def __str__(self): + return f"id: {self.graph_id if self.graph_id else 'no_id'}, \tname: {self.name if self.name else 'no_name'}" + \ + "".join((f", \t{key}: {str(param)}" for key, param in self._parameters.items())) + @property def name(self) -> Name: return self._name diff --git a/b_asic/operation.py b/b_asic/operation.py index 90e9adeff122d0bfbcde9dc1e0a9126aa42b939e..21e7012eaf7a333c5db8a7f8a6c741b3220030b8 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -262,6 +262,47 @@ class AbstractOperation(Operation, AbstractGraphComponent): from b_asic.core_operations import Constant, Division return Division(Constant(src) if isinstance(src, Number) else src, self) + def __str__(self): + inputs_dict = dict() + for i, port in enumerate(self.inputs): + if port.signal_count == 0: + inputs_dict[i] = '-' + break + dict_ele = [] + for signal in port.signals: + if signal.source: + if signal.source.operation.graph_id: + dict_ele.append(signal.source.operation.graph_id) + else: + dict_ele.append("no_id") + else: + if signal.graph_id: + dict_ele.append(signal.graph_id) + else: + dict_ele.append("no_id") + inputs_dict[i] = dict_ele + + outputs_dict = dict() + for i, port in enumerate(self.outputs): + if port.signal_count == 0: + outputs_dict[i] = '-' + break + dict_ele = [] + for signal in port.signals: + if signal.destination: + if signal.destination.operation.graph_id: + dict_ele.append(signal.destination.operation.graph_id) + else: + dict_ele.append("no_id") + else: + if signal.graph_id: + dict_ele.append(signal.graph_id) + else: + dict_ele.append("no_id") + outputs_dict[i] = dict_ele + + return super().__str__() + f", \tinputs: {str(inputs_dict)}, \toutputs: {str(outputs_dict)}" + @property def input_count(self) -> int: return len(self._input_ports) @@ -400,6 +441,16 @@ class AbstractOperation(Operation, AbstractGraphComponent): def neighbors(self) -> Iterable[GraphComponent]: return list(self.input_signals) + list(self.output_signals) + @property + def preceding_operations(self) -> Iterable[Operation]: + """Returns an Iterable of all Operations that are connected to this Operations input ports.""" + return [signal.source.operation for signal in self.input_signals if signal.source] + + @property + def subsequent_operations(self) -> Iterable[Operation]: + """Returns an Iterable of all Operations that are connected to this Operations output ports.""" + return [signal.destination.operation for signal in self.output_signals if signal.destination] + @property def source(self) -> OutputPort: if self.output_count != 1: diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 6529dfd7d355f062c18ae16f2307587ec5e4cd80..10a383bd7c0e8f300ef0ac7f045fd500e5e7af92 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -6,13 +6,15 @@ TODO: More info. from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, MutableSet from numbers import Number from collections import defaultdict, deque +from io import StringIO +from queue import PriorityQueue +import itertools -from b_asic.port import SignalSourceProvider, OutputPort, InputPort +from b_asic.port import SignalSourceProvider, OutputPort from b_asic.operation import Operation, AbstractOperation, MutableOutputMap, MutableRegisterMap from b_asic.signal import Signal from b_asic.graph_component import GraphID, GraphIDNumber, GraphComponent, Name, TypeName from b_asic.special_operations import Input, Output, Register -from b_asic.core_operations import Constant class GraphIDGenerator: @@ -41,8 +43,9 @@ class SFG(AbstractOperation): _components_by_id: Dict[GraphID, GraphComponent] _components_by_name: DefaultDict[Name, List[GraphComponent]] - _components_ordered: List[GraphComponent] - _operations_ordered: List[Operation] + _components_dfs_order: List[GraphComponent] + _operations_dfs_order: List[Operation] + _operations_topological_order: List[Operation] _graph_id_generator: GraphIDGenerator _input_operations: List[Input] _output_operations: List[Output] @@ -67,8 +70,9 @@ class SFG(AbstractOperation): self._components_by_id = dict() self._components_by_name = defaultdict(list) - self._components_ordered = [] - self._operations_ordered = [] + self._components_dfs_order = [] + self._operations_dfs_order = [] + self._operations_topological_order = [] self._graph_id_generator = GraphIDGenerator(id_number_offset) self._input_operations = [] self._output_operations = [] @@ -151,9 +155,9 @@ class SFG(AbstractOperation): signal.destination.operation) elif new_signal.destination.operation in output_operations_set: # Add directly connected input to output to ordered list. - self._components_ordered.extend( + self._components_dfs_order.extend( [new_signal.source.operation, new_signal, new_signal.destination.operation]) - self._operations_ordered.extend( + self._operations_dfs_order.extend( [new_signal.source.operation, new_signal.destination.operation]) # Search the graph inwards from each output signal. @@ -170,47 +174,18 @@ class SFG(AbstractOperation): def __str__(self) -> str: """Get a string representation of this SFG.""" - output_string = "" - for component in self._components_ordered: - if isinstance(component, Operation): - for key, value in self._components_by_id.items(): - if value is component: - output_string += "id: " + key + ", name: " - - if component.name != None: - output_string += component.name + ", " - else: - output_string += "-, " + string_io = StringIO() + string_io.write(super().__str__() + "\n") + string_io.write("Internal Operations:\n") + line = "-" * 100 + "\n" + string_io.write(line) - if isinstance(component, Constant): - output_string += "value: " + \ - str(component.value) + ", input: [" - else: - output_string += "input: [" - - counter_input = 0 - for input in component.inputs: - counter_input += 1 - for signal in input.signals: - for key, value in self._components_by_id.items(): - if value is signal: - output_string += key + ", " - - if counter_input > 0: - output_string = output_string[:-2] - output_string += "], output: [" - counter_output = 0 - for output in component.outputs: - counter_output += 1 - for signal in output.signals: - for key, value in self._components_by_id.items(): - if value is signal: - output_string += key + ", " - if counter_output > 0: - output_string = output_string[:-2] - output_string += "]\n" - - return output_string + for operation in self.get_operations_topological_order(): + string_io.write(str(operation) + "\n") + + string_io.write(line) + + return string_io.getvalue() def __call__(self, *src: Optional[SignalSourceProvider], name: Name = "") -> "SFG": """Get a new independent SFG instance that is identical to this SFG except without any of its external connections.""" @@ -248,7 +223,7 @@ class SFG(AbstractOperation): return value def connect_external_signals_to_components(self) -> bool: - """ Connects any external signals to this SFG's internal operations. This SFG becomes unconnected to the SFG + """ Connects any external signals to this SFG's internal operations. This SFG becomes unconnected to the SFG it is a component off, causing it to become invalid afterwards. Returns True if succesful, False otherwise. """ if len(self.inputs) != len(self.input_operations): raise IndexError(f"Number of inputs does not match the number of input_operations in SFG.") @@ -264,7 +239,7 @@ class SFG(AbstractOperation): dest = input_operation.output(0).signals[0].destination dest.clear() port.signals[0].set_destination(dest) - # For each output_signal, connect it to the corresponding operation + # For each output_signal, connect it to the corresponding operation for port, output_operation in zip(self.outputs, self.output_operations): src = output_operation.input(0).signals[0].source src.clear() @@ -328,12 +303,12 @@ class SFG(AbstractOperation): @property def components(self) -> Iterable[GraphComponent]: """Get all components of this graph in depth-first order.""" - return self._components_ordered + return self._components_dfs_order @property def operations(self) -> Iterable[Operation]: """Get all operations of this graph in depth-first order.""" - return self._operations_ordered + return self._operations_dfs_order def get_components_with_type_name(self, type_name: TypeName) -> List[GraphComponent]: """Get a list with all components in this graph with the specified type_name. @@ -387,8 +362,8 @@ class SFG(AbstractOperation): new_op = None if original_op not in self._original_components_to_new: new_op = self._add_component_unconnected_copy(original_op) - self._components_ordered.append(new_op) - self._operations_ordered.append(new_op) + self._components_dfs_order.append(new_op) + self._operations_dfs_order.append(new_op) else: new_op = self._original_components_to_new[original_op] @@ -402,24 +377,20 @@ class SFG(AbstractOperation): if original_signal in self._original_input_signals_to_indices: # New signal already created during first step of constructor. new_signal = self._original_components_to_new[original_signal] - new_signal.set_destination( - new_op.input(original_input_port.index)) - self._components_ordered.extend( - [new_signal, new_signal.source.operation]) - self._operations_ordered.append( - new_signal.source.operation) + new_signal.set_destination(new_op.input(original_input_port.index)) + + self._components_dfs_order.extend([new_signal, new_signal.source.operation]) + self._operations_dfs_order.append(new_signal.source.operation) # Check if the signal has not been added before. elif original_signal not in self._original_components_to_new: if original_signal.source is None: - raise ValueError( - "Dangling signal without source in SFG") + raise ValueError("Dangling signal without source in SFG") - new_signal = self._add_component_unconnected_copy( - original_signal) - new_signal.set_destination( - new_op.input(original_input_port.index)) - self._components_ordered.append(new_signal) + new_signal = self._add_component_unconnected_copy(original_signal) + new_signal.set_destination(new_op.input(original_input_port.index)) + + self._components_dfs_order.append(new_signal) original_connected_op = original_signal.source.operation # Check if connected Operation has been added before. @@ -429,12 +400,11 @@ class SFG(AbstractOperation): original_signal.source.index)) else: # Create new operation, set signal source to it. - new_connected_op = self._add_component_unconnected_copy( - original_connected_op) - new_signal.set_source(new_connected_op.output( - original_signal.source.index)) - self._components_ordered.append(new_connected_op) - self._operations_ordered.append(new_connected_op) + new_connected_op = self._add_component_unconnected_copy(original_connected_op) + new_signal.set_source(new_connected_op.output(original_signal.source.index)) + + self._components_dfs_order.append(new_connected_op) + self._operations_dfs_order.append(new_connected_op) # Add connected operation to queue of operations to visit. op_stack.append(original_connected_op) @@ -446,24 +416,20 @@ class SFG(AbstractOperation): if original_signal in self._original_output_signals_to_indices: # New signal already created during first step of constructor. new_signal = self._original_components_to_new[original_signal] - new_signal.set_source( - new_op.output(original_output_port.index)) - self._components_ordered.extend( - [new_signal, new_signal.destination.operation]) - self._operations_ordered.append( - new_signal.destination.operation) + new_signal.set_source(new_op.output(original_output_port.index)) + + self._components_dfs_order.extend([new_signal, new_signal.destination.operation]) + self._operations_dfs_order.append(new_signal.destination.operation) # Check if signal has not been added before. elif original_signal not in self._original_components_to_new: if original_signal.source is None: - raise ValueError( - "Dangling signal without source in SFG") + raise ValueError("Dangling signal without source in SFG") - new_signal = self._add_component_unconnected_copy( - original_signal) - new_signal.set_source( - new_op.output(original_output_port.index)) - self._components_ordered.append(new_signal) + new_signal = self._add_component_unconnected_copy(original_signal) + new_signal.set_source(new_op.output(original_output_port.index)) + + self._components_dfs_order.append(new_signal) original_connected_op = original_signal.destination.operation # Check if connected operation has been added. @@ -473,12 +439,11 @@ class SFG(AbstractOperation): original_signal.destination.index)) else: # Create new operation, set destination to it. - new_connected_op = self._add_component_unconnected_copy( - original_connected_op) - new_signal.set_destination(new_connected_op.input( - original_signal.destination.index)) - self._components_ordered.append(new_connected_op) - self._operations_ordered.append(new_connected_op) + new_connected_op = self._add_component_unconnected_copy(original_connected_op) + new_signal.set_destination(new_connected_op.input(original_signal.destination.index)) + + self._components_dfs_order.append(new_connected_op) + self._operations_dfs_order.append(new_connected_op) # Add connected operation to the queue of operations to visit. op_stack.append(original_connected_op) @@ -556,16 +521,13 @@ class SFG(AbstractOperation): if key in results: value = results[key] if value is None: - raise RuntimeError( - f"Direct feedback loop detected when evaluating operation.") + raise RuntimeError(f"Direct feedback loop detected when evaluating operation.") return value - results[key] = src.operation.current_output( - src.index, registers, src_prefix) + results[key] = src.operation.current_output(src.index, registers, src_prefix) input_values = [self._evaluate_source( input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs] - value = src.operation.evaluate_output( - src.index, input_values, results, registers, src_prefix) + value = src.operation.evaluate_output(src.index, input_values, results, registers, src_prefix) results[key] = value return value @@ -573,7 +535,7 @@ class SFG(AbstractOperation): """Returns a Precedence list of the SFG where each element in n:th the list consists of elements that are executed in the n:th step. If the precedence list already has been calculated for the current SFG then returns the cached version.""" - if self._precedence_list is not None: + if self._precedence_list: return self._precedence_list # Find all operations with only outputs and no inputs. @@ -587,17 +549,9 @@ class SFG(AbstractOperation): return self._precedence_list - def _traverse_for_precedence_list(self, first_iter_ports): + def _traverse_for_precedence_list(self, first_iter_ports: List[OutputPort]) -> List[List[OutputPort]]: # Find dependencies of output ports and input ports. - outports_per_inport = defaultdict(list) - remaining_inports_per_outport = dict() - for op in self.operations: - op_inputs = op.inputs - for out_i, outport in enumerate(op.outputs): - dependendent_indexes = op.inputs_required_for_output(out_i) - remaining_inports_per_outport[outport] = len(dependendent_indexes) - for in_i in dependendent_indexes: - outports_per_inport[op_inputs[in_i]].append(outport) + remaining_inports_per_operation = {op: op.input_count for op in self.operations} # Traverse output ports for precedence curr_iter_ports = first_iter_ports @@ -614,11 +568,113 @@ class SFG(AbstractOperation): new_inport = signal.destination # Don't traverse over Registers if new_inport is not None and not isinstance(new_inport.operation, Register): - for new_outport in outports_per_inport[new_inport]: - remaining_inports_per_outport[new_outport] -= 1 - if remaining_inports_per_outport[new_outport] == 0: - next_iter_ports.append(new_outport) + new_op = new_inport.operation + remaining_inports_per_operation[new_op] -= 1 + if remaining_inports_per_operation[new_op] == 0: + next_iter_ports.extend(new_op.outputs) curr_iter_ports = next_iter_ports return precedence_list + + def print_precedence_graph(self) -> None: + """Prints a representation of the SFG's precedence list to the standard out. + If the precedence list already has been calculated then it uses the cached version, + otherwise it calculates the precedence list and then prints it.""" + precedence_list = self.get_precedence_list() + + line = "-" * 120 + out_str = StringIO() + out_str.write(line) + + printed_ops = set() + + for iter_num, iter in enumerate(precedence_list, start=1): + for outport_num, outport in enumerate(iter, start=1): + if outport not in printed_ops: + # Only print once per operation, even if it has multiple outports + out_str.write("\n") + out_str.write(str(iter_num)) + out_str.write(".") + out_str.write(str(outport_num)) + out_str.write(" \t") + out_str.write(str(outport.operation)) + printed_ops.add(outport) + + out_str.write("\n") + out_str.write(line) + + print(out_str.getvalue()) + + def get_operations_topological_order(self) -> Iterable[Operation]: + """Returns an Iterable of the Operations in the SFG in Topological Order. + Feedback loops makes an absolutely correct Topological order impossible, so an + approximative Topological Order is returned in such cases in this implementation.""" + if self._operations_topological_order: + return self._operations_topological_order + + no_inputs_queue = deque(list(filter(lambda op: op.input_count == 0, self.operations))) + remaining_inports_per_operation = {op: op.input_count for op in self.operations} + + # Maps number of input counts to a queue of seen objects with such a size. + seen_with_inputs_dict = defaultdict(deque) + seen = set() + top_order = [] + + assert len(no_inputs_queue) > 0, "Illegal SFG state, dangling signals in SFG." + + first_op = no_inputs_queue.popleft() + visited = set([first_op]) + p_queue = PriorityQueue() + p_queue.put((-first_op.output_count, first_op)) # Negative priority as max-heap popping is wanted + operations_left = len(self.operations) - 1 + + seen_but_not_visited_count = 0 + + while operations_left > 0: + while not p_queue.empty(): + op = p_queue.get()[1] + + operations_left -= 1 + top_order.append(op) + visited.add(op) + + for neighbor_op in op.subsequent_operations: + if neighbor_op not in visited: + remaining_inports_per_operation[neighbor_op] -= 1 + remaining_inports = remaining_inports_per_operation[neighbor_op] + + if remaining_inports == 0: + p_queue.put((-neighbor_op.output_count, neighbor_op)) + + elif remaining_inports > 0: + if neighbor_op in seen: + seen_with_inputs_dict[remaining_inports + 1].remove(neighbor_op) + else: + seen.add(neighbor_op) + seen_but_not_visited_count += 1 + + seen_with_inputs_dict[remaining_inports].append(neighbor_op) + + # Check if have to fetch Operations from somewhere else since p_queue is empty + if operations_left > 0: + # First check if can fetch from Operations with no input ports + if no_inputs_queue: + new_op = no_inputs_queue.popleft() + p_queue.put((new_op.output_count, new_op)) + + # Else fetch operation with lowest input count that is not zero + elif seen_but_not_visited_count > 0: + for i in itertools.count(start=1): + seen_inputs_queue = seen_with_inputs_dict[i] + if seen_inputs_queue: + new_op = seen_inputs_queue.popleft() + p_queue.put((-new_op.output_count, new_op)) + seen_but_not_visited_count -= 1 + break + else: + raise RuntimeError("Unallowed structure in SFG detected") + + self._operations_topological_order = top_order + + return self._operations_topological_order diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index 5a0ef25b94cec8e3fad9275cccf97882703de330..e2145b0a2a5974222c8c3d740cb3f53d76c7e445 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -1,12 +1,12 @@ import pytest -from b_asic import SFG, Input, Output, Constant, Register, ConstantMultiplication +from b_asic import SFG, Input, Output, Constant, Register, ConstantMultiplication, Addition, Butterfly @pytest.fixture def sfg_two_inputs_two_outputs(): """Valid SFG with two inputs and two outputs. - . . + . . in1-------+ +--------->out1 . | | . . v | . @@ -17,9 +17,9 @@ def sfg_two_inputs_two_outputs(): | . ^ . | . | . +------------+ . - . . + . . out1 = in1 + in2 - out2 = in1 + 2 * in2 + out2 = in1 + 2 * in2 """ in1 = Input() in2 = Input() @@ -27,13 +27,14 @@ def sfg_two_inputs_two_outputs(): add2 = add1 + in2 out1 = Output(add1) out2 = Output(add2) - return SFG(inputs = [in1, in2], outputs = [out1, out2]) + return SFG(inputs=[in1, in2], outputs=[out1, out2]) + @pytest.fixture def sfg_two_inputs_two_outputs_independent(): """Valid SFG with two inputs and two outputs, where the first output only depends on the first input and the second output only depends on the second input. - . . + . . in1-------------------->out1 . . . . @@ -44,17 +45,18 @@ def sfg_two_inputs_two_outputs_independent(): . | ^ . . | | . . +------+ . - . . + . . out1 = in1 - out2 = in2 + 3 + out2 = in2 + 3 """ - in1 = Input() - in2 = Input() - c1 = Constant(3) - add1 = in2 + c1 - out1 = Output(in1) - out2 = Output(add1) - return SFG(inputs = [in1, in2], outputs = [out1, out2]) + in1 = Input("IN1") + in2 = Input("IN2") + c1 = Constant(3, "C1") + add1 = Addition(in2, c1, "ADD1") + out1 = Output(in1, "OUT1") + out2 = Output(add1, "OUT2") + return SFG(inputs=[in1, in2], outputs=[out1, out2]) + @pytest.fixture def sfg_nested(): @@ -65,7 +67,7 @@ def sfg_nested(): mac_in2 = Input() mac_in3 = Input() mac_out1 = Output(mac_in1 + mac_in2 * mac_in3) - MAC = SFG(inputs = [mac_in1, mac_in2, mac_in3], outputs = [mac_out1]) + MAC = SFG(inputs=[mac_in1, mac_in2, mac_in3], outputs=[mac_out1]) in1 = Input() in2 = Input() @@ -73,7 +75,8 @@ def sfg_nested(): mac2 = MAC(in1, in2, mac1) mac3 = MAC(in1, mac1, mac2) out1 = Output(mac3) - return SFG(inputs = [in1, in2], outputs = [out1]) + return SFG(inputs=[in1, in2], outputs=[out1]) + @pytest.fixture def sfg_delay(): @@ -83,7 +86,8 @@ def sfg_delay(): in1 = Input() reg1 = Register(in1) out1 = Output(reg1) - return SFG(inputs = [in1], outputs = [out1]) + return SFG(inputs=[in1], outputs=[out1]) + @pytest.fixture def sfg_accumulator(): @@ -95,7 +99,8 @@ def sfg_accumulator(): reg = Register() reg.input(0).connect((reg + data_in) * (1 - reset)) data_out = Output(reg) - return SFG(inputs = [data_in, reset], outputs = [data_out]) + return SFG(inputs=[data_in, reset], outputs=[data_out]) + @pytest.fixture def simple_filter(): @@ -105,11 +110,70 @@ def simple_filter(): | | in1>------add1>------reg>------+------out1> """ - in1 = Input() - reg = Register() - constmul1 = ConstantMultiplication(0.5) - add1 = in1 + constmul1 - reg.input(0).connect(add1) + in1 = Input("IN1") + constmul1 = ConstantMultiplication(0.5, name="CMUL1") + add1 = Addition(in1, constmul1, "ADD1") + reg = Register(add1, name="REG1") constmul1.input(0).connect(reg) - out1 = Output(reg) - return SFG(inputs=[in1], outputs=[out1]) + out1 = Output(reg, "OUT1") + return SFG(inputs=[in1], outputs=[out1], name="simple_filter") + + +@pytest.fixture +def precedence_sfg_registers(): + """A sfg with registers and interesting layout for precednce list generation. + + IN1>--->C0>--->ADD1>--->Q1>---+--->A0>--->ADD4>--->OUT1 + ^ | ^ + | T1 | + | | | + ADD2<---<B1<---+--->A1>--->ADD3 + ^ | ^ + | T2 | + | | | + +-----<B2<---+--->A2>-----+ + """ + in1 = Input("IN1") + c0 = ConstantMultiplication(5, in1, "C0") + add1 = Addition(c0, None, "ADD1") + # Not sure what operation "Q" is supposed to be in the example + Q1 = ConstantMultiplication(1, add1, "Q1") + T1 = Register(Q1, 0, "T1") + T2 = Register(T1, 0, "T2") + b2 = ConstantMultiplication(2, T2, "B2") + b1 = ConstantMultiplication(3, T1, "B1") + add2 = Addition(b1, b2, "ADD2") + add1.input(1).connect(add2) + a1 = ConstantMultiplication(4, T1, "A1") + a2 = ConstantMultiplication(6, T2, "A2") + add3 = Addition(a1, a2, "ADD3") + a0 = ConstantMultiplication(7, Q1, "A0") + add4 = Addition(a0, add3, "ADD4") + out1 = Output(add4, "OUT1") + + return SFG(inputs=[in1], outputs=[out1], name="SFG") + + +@pytest.fixture +def precedence_sfg_registers_and_constants(): + in1 = Input("IN1") + c0 = ConstantMultiplication(5, in1, "C0") + add1 = Addition(c0, None, "ADD1") + # Not sure what operation "Q" is supposed to be in the example + Q1 = ConstantMultiplication(1, add1, "Q1") + T1 = Register(Q1, 0, "T1") + const1 = Constant(10, "CONST1") # Replace T2 register with a constant + b2 = ConstantMultiplication(2, const1, "B2") + b1 = ConstantMultiplication(3, T1, "B1") + add2 = Addition(b1, b2, "ADD2") + add1.input(1).connect(add2) + a1 = ConstantMultiplication(4, T1, "A1") + a2 = ConstantMultiplication(10, const1, "A2") + add3 = Addition(a1, a2, "ADD3") + a0 = ConstantMultiplication(7, Q1, "A0") + # Replace ADD4 with a butterfly to test multiple output ports + bfly1 = Butterfly(a0, add3, "BFLY1") + out1 = Output(bfly1.output(0), "OUT1") + out2 = Output(bfly1.output(1), "OUT2") + + return SFG(inputs=[in1], outputs=[out1], name="SFG") diff --git a/test/test_sfg.py b/test/test_sfg.py index 5f86739517b0d4c7bc9b242de24c3777222b51d2..b6625766b2f17ea142a875ffa6f767786cc4ad94 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1,4 +1,7 @@ import pytest +import io +import sys + from b_asic import SFG, Signal, Input, Output, Constant, ConstantMultiplication, Addition, Multiplication, Register, \ Butterfly, Subtraction, SquareRoot @@ -54,13 +57,17 @@ class TestPrintSfg: inp2 = Input("INP2") add1 = Addition(inp1, inp2, "ADD1") out1 = Output(add1, "OUT1") - sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="sf1") + sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="SFG1") assert sfg.__str__() == \ - "id: add1, name: ADD1, input: [s1, s2], output: [s3]\n" + \ - "id: in1, name: INP1, input: [], output: [s1]\n" + \ - "id: in2, name: INP2, input: [], output: [s2]\n" + \ - "id: out1, name: OUT1, input: [s3], output: []\n" + "id: no_id, \tname: SFG1, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + \ + "Internal Operations:\n" + \ + "----------------------------------------------------------------------------------------------------\n" + \ + str(sfg.find_by_name("INP1")[0]) + "\n" + \ + str(sfg.find_by_name("INP2")[0]) + "\n" + \ + str(sfg.find_by_name("ADD1")[0]) + "\n" + \ + str(sfg.find_by_name("OUT1")[0]) + "\n" + \ + "----------------------------------------------------------------------------------------------------\n" def test_add_mul(self): inp1 = Input("INP1") @@ -72,12 +79,16 @@ class TestPrintSfg: sfg = SFG(inputs=[inp1, inp2, inp3], outputs=[out1], name="mac_sfg") assert sfg.__str__() == \ - "id: add1, name: ADD1, input: [s1, s2], output: [s5]\n" + \ - "id: in1, name: INP1, input: [], output: [s1]\n" + \ - "id: in2, name: INP2, input: [], output: [s2]\n" + \ - "id: mul1, name: MUL1, input: [s5, s3], output: [s4]\n" + \ - "id: in3, name: INP3, input: [], output: [s3]\n" + \ - "id: out1, name: OUT1, input: [s4], output: []\n" + "id: no_id, \tname: mac_sfg, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + \ + "Internal Operations:\n" + \ + "----------------------------------------------------------------------------------------------------\n" + \ + str(sfg.find_by_name("INP1")[0]) + "\n" + \ + str(sfg.find_by_name("INP2")[0]) + "\n" + \ + str(sfg.find_by_name("ADD1")[0]) + "\n" + \ + str(sfg.find_by_name("INP3")[0]) + "\n" + \ + str(sfg.find_by_name("MUL1")[0]) + "\n" + \ + str(sfg.find_by_name("OUT1")[0]) + "\n" + \ + "----------------------------------------------------------------------------------------------------\n" def test_constant(self): inp1 = Input("INP1") @@ -88,18 +99,27 @@ class TestPrintSfg: sfg = SFG(inputs=[inp1], outputs=[out1], name="sfg") assert sfg.__str__() == \ - "id: add1, name: ADD1, input: [s3, s1], output: [s2]\n" + \ - "id: c1, name: CONST, value: 3, input: [], output: [s3]\n" + \ - "id: in1, name: INP1, input: [], output: [s1]\n" + \ - "id: out1, name: OUT1, input: [s2], output: []\n" + "id: no_id, \tname: sfg, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + \ + "Internal Operations:\n" + \ + "----------------------------------------------------------------------------------------------------\n" + \ + str(sfg.find_by_name("CONST")[0]) + "\n" + \ + str(sfg.find_by_name("INP1")[0]) + "\n" + \ + str(sfg.find_by_name("ADD1")[0]) + "\n" + \ + str(sfg.find_by_name("OUT1")[0]) + "\n" + \ + "----------------------------------------------------------------------------------------------------\n" def test_simple_filter(self, simple_filter): + assert simple_filter.__str__() == \ - 'id: add1, name: , input: [s1, s3], output: [s4]\n' + \ - 'id: in1, name: , input: [], output: [s1]\n' + \ - 'id: cmul1, name: , input: [s5], output: [s3]\n' + \ - 'id: reg1, name: , input: [s4], output: [s5, s2]\n' + \ - 'id: out1, name: , input: [s2], output: []\n' + "id: no_id, \tname: simple_filter, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + \ + "Internal Operations:\n" + \ + "----------------------------------------------------------------------------------------------------\n" + \ + str(simple_filter.find_by_name("IN1")[0]) + "\n" + \ + str(simple_filter.find_by_name("ADD1")[0]) + "\n" + \ + str(simple_filter.find_by_name("REG1")[0]) + "\n" + \ + str(simple_filter.find_by_name("CMUL1")[0]) + "\n" + \ + str(simple_filter.find_by_name("OUT1")[0]) + "\n" + \ + "----------------------------------------------------------------------------------------------------\n" class TestDeepCopy: @@ -267,7 +287,7 @@ class TestInsertComponent: _sfg = sfg.insert_operation(sqrt, sfg.find_by_name("constant4")[0].graph_id) assert _sfg.evaluate() != sfg.evaluate() - + assert any([isinstance(comp, SquareRoot) for comp in _sfg.operations]) assert not any([isinstance(comp, SquareRoot) for comp in sfg.operations]) @@ -275,7 +295,8 @@ class TestInsertComponent: assert isinstance(_sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot) assert sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is sfg.find_by_id("add3") - assert _sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is not _sfg.find_by_id("add3") + assert _sfg.find_by_name("constant4")[0].output( + 0).signals[0].destination.operation is not _sfg.find_by_id("add3") assert _sfg.find_by_id("sqrt1").output(0).signals[0].destination.operation is _sfg.find_by_id("add3") def test_insert_invalid_component_in_sfg(self, large_operation_tree): @@ -304,22 +325,26 @@ class TestInsertComponent: assert len(_sfg.find_by_name("n_bfly")) == 1 # Correctly connected old output -> new input - assert _sfg.find_by_name("bfly3")[0].output(0).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] - assert _sfg.find_by_name("bfly3")[0].output(1).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] + assert _sfg.find_by_name("bfly3")[0].output( + 0).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] + assert _sfg.find_by_name("bfly3")[0].output( + 1).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] # Correctly connected new input -> old output assert _sfg.find_by_name("n_bfly")[0].input(0).signals[0].source.operation is _sfg.find_by_name("bfly3")[0] assert _sfg.find_by_name("n_bfly")[0].input(1).signals[0].source.operation is _sfg.find_by_name("bfly3")[0] # Correctly connected new output -> next input - assert _sfg.find_by_name("n_bfly")[0].output(0).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] - assert _sfg.find_by_name("n_bfly")[0].output(1).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] + assert _sfg.find_by_name("n_bfly")[0].output( + 0).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] + assert _sfg.find_by_name("n_bfly")[0].output( + 1).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] # Correctly connected next input -> new output assert _sfg.find_by_name("bfly2")[0].input(0).signals[0].source.operation is _sfg.find_by_name("n_bfly")[0] assert _sfg.find_by_name("bfly2")[0].input(1).signals[0].source.operation is _sfg.find_by_name("n_bfly")[0] - + class TestFindComponentsWithTypeName: def test_mac_components(self): inp1 = Input("INP1") @@ -358,28 +383,9 @@ class TestFindComponentsWithTypeName: class TestGetPrecedenceList: - def test_inputs_registers(self): - in1 = Input("IN1") - c0 = ConstantMultiplication(5, in1, "C0") - add1 = Addition(c0, None, "ADD1") - # Not sure what operation "Q" is supposed to be in the example - Q1 = ConstantMultiplication(1, add1, "Q1") - T1 = Register(Q1, 0, "T1") - T2 = Register(T1, 0, "T2") - b2 = ConstantMultiplication(2, T2, "B2") - b1 = ConstantMultiplication(3, T1, "B1") - add2 = Addition(b1, b2, "ADD2") - add1.input(1).connect(add2) - a1 = ConstantMultiplication(4, T1, "A1") - a2 = ConstantMultiplication(6, T2, "A2") - add3 = Addition(a1, a2, "ADD3") - a0 = ConstantMultiplication(7, Q1, "A0") - add4 = Addition(a0, add3, "ADD4") - out1 = Output(add4, "OUT1") - - sfg = SFG(inputs=[in1], outputs=[out1], name="SFG") + def test_inputs_registers(self, precedence_sfg_registers): - precedence_list = sfg.get_precedence_list() + precedence_list = precedence_sfg_registers.get_precedence_list() assert len(precedence_list) == 7 @@ -404,30 +410,9 @@ class TestGetPrecedenceList: assert set([port.operation.key(port.index, port.operation.name) for port in precedence_list[6]]) == {"ADD4"} - def test_inputs_constants_registers_multiple_outputs(self): - in1 = Input("IN1") - c0 = ConstantMultiplication(5, in1, "C0") - add1 = Addition(c0, None, "ADD1") - # Not sure what operation "Q" is supposed to be in the example - Q1 = ConstantMultiplication(1, add1, "Q1") - T1 = Register(Q1, 0, "T1") - const1 = Constant(10, "CONST1") # Replace T2 register with a constant - b2 = ConstantMultiplication(2, const1, "B2") - b1 = ConstantMultiplication(3, T1, "B1") - add2 = Addition(b1, b2, "ADD2") - add1.input(1).connect(add2) - a1 = ConstantMultiplication(4, T1, "A1") - a2 = ConstantMultiplication(10, const1, "A2") - add3 = Addition(a1, a2, "ADD3") - a0 = ConstantMultiplication(7, Q1, "A0") - # Replace ADD4 with a butterfly to test multiple output ports - bfly1 = Butterfly(a0, add3, "BFLY1") - out1 = Output(bfly1.output(0), "OUT1") - out2 = Output(bfly1.output(1), "OUT2") - - sfg = SFG(inputs=[in1], outputs=[out1], name="SFG") + def test_inputs_constants_registers_multiple_outputs(self, precedence_sfg_registers_and_constants): - precedence_list = sfg.get_precedence_list() + precedence_list = precedence_sfg_registers_and_constants.get_precedence_list() assert len(precedence_list) == 7 @@ -502,10 +487,48 @@ class TestGetPrecedenceList: for port in precedence_list[0]]) == {"IN1", "IN2"} assert set([port.operation.key(port.index, port.operation.name) - for port in precedence_list[1]]) == {"NESTED_SFG.0", "CMUL1"} + for port in precedence_list[1]]) == {"CMUL1"} assert set([port.operation.key(port.index, port.operation.name) - for port in precedence_list[2]]) == {"NESTED_SFG.1"} + for port in precedence_list[2]]) == {"NESTED_SFG.0", "NESTED_SFG.1"} + + +class TestPrintPrecedence: + def test_registers(self, precedence_sfg_registers): + sfg = precedence_sfg_registers + + captured_output = io.StringIO() + sys.stdout = captured_output + + sfg.print_precedence_graph() + + sys.stdout = sys.__stdout__ + + captured_output = captured_output.getvalue() + + assert captured_output == \ + "-" * 120 + "\n" + \ + "1.1 \t" + str(sfg.find_by_name("IN1")[0]) + "\n" + \ + "1.2 \t" + str(sfg.find_by_name("T1")[0]) + "\n" + \ + "1.3 \t" + str(sfg.find_by_name("T2")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "2.1 \t" + str(sfg.find_by_name("C0")[0]) + "\n" + \ + "2.2 \t" + str(sfg.find_by_name("A1")[0]) + "\n" + \ + "2.3 \t" + str(sfg.find_by_name("B1")[0]) + "\n" + \ + "2.4 \t" + str(sfg.find_by_name("A2")[0]) + "\n" + \ + "2.5 \t" + str(sfg.find_by_name("B2")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "3.1 \t" + str(sfg.find_by_name("ADD3")[0]) + "\n" + \ + "3.2 \t" + str(sfg.find_by_name("ADD2")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "4.1 \t" + str(sfg.find_by_name("ADD1")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "5.1 \t" + str(sfg.find_by_name("Q1")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "6.1 \t" + str(sfg.find_by_name("A0")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "7.1 \t" + str(sfg.find_by_name("ADD4")[0]) + "\n" + \ + "-" * 120 + "\n" class TestDepends: @@ -672,3 +695,15 @@ class TestConnectExternalSignalsToComponentsMultipleComp: sfg1.connect_external_signals_to_components() assert test_sfg.evaluate(1, 2, 3, 4) == 16 assert not test_sfg.connect_external_signals_to_components() + + +class TestTopologicalOrderOperations: + def test_feedback_sfg(self, simple_filter): + topological_order = simple_filter.get_operations_topological_order() + + assert [comp.name for comp in topological_order] == ["IN1", "ADD1", "REG1", "CMUL1", "OUT1"] + + def test_multiple_independent_inputs(self, sfg_two_inputs_two_outputs_independent): + topological_order = sfg_two_inputs_two_outputs_independent.get_operations_topological_order() + + assert [comp.name for comp in topological_order] == ["IN1", "OUT1", "IN2", "C1", "ADD1", "OUT2"]