diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 417baa249c1ff9ef91e693a4a67a84ed6af7392e..43216d627c85a65c8a121fef1f47309a26b26322 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -26,6 +26,7 @@ from typing import ( ) from graphviz import Digraph +from matplotlib.axes import itertools from b_asic.graph_component import GraphComponent from b_asic.operation import ( @@ -1164,8 +1165,14 @@ class SFG(AbstractOperation): original_signal not in self._original_components_to_new ): if original_signal.source is None: + dest = ( + original_signal.destination.operation.name + if original_signal.destination is not None + else "None" + ) raise ValueError( "Dangling signal without source in SFG" + f" (destination: {dest})" ) new_signal = cast( @@ -1486,3 +1493,162 @@ class SFG(AbstractOperation): from b_asic.schedule import Schedule return Schedule(self, scheduling_algorithm="ASAP").schedule_time + + def unfold(self, factor: int) -> "SFG": + """ + Unfold the SFG *factor* times. Return a new SFG without modifying the original. + + Inputs and outputs are ordered with early inputs first. That is for an SFG + with n inputs, the first n inputs are the inputs at time t, the next n + inputs are the inputs at time t+1, the next n at t+2 and so on. + + Parameters + ---------- + factor : string, optional + Number of times to unfold + """ + + if factor == 0: + raise ValueError("Unfolding 0 times removes the SFG") + + # Make `factor` copies of the sfg + new_ops = [ + [cast(Operation, op.copy_component()) for op in self.operations] + for _ in range(factor) + ] + + id_idx_map = { + op.graph_id: idx for (idx, op) in enumerate(self.operations) + } + + # The rest of the process is easier if we clear the connections of the inputs + # and outputs of all operations + for layer, op_list in enumerate(new_ops): + for op_idx, op in enumerate(op_list): + for input in op.inputs: + input.clear() + for output in op.outputs: + output.clear() + + suffix = layer + + new_ops[layer][ + op_idx + ].name = f"{new_ops[layer][op_idx].name}_{suffix}" + # NOTE: Since these IDs are what show up when printing the graph, it + # is helpful to set them. However, this can cause name collisions when + # names in a graph are already suffixed with _n + new_ops[layer][op_idx].graph_id = GraphID( + f"{new_ops[layer][op_idx].graph_id}_{suffix}" + ) + + # Walk through the operations, replacing delay nodes with connections + for layer in range(factor): + for op_idx, op in enumerate(self.operations): + if isinstance(op, Delay): + # Port of the operation feeding into this delay + source_port = op.inputs[0].connected_source + if source_port is None: + raise ValueError("Dangling delay input port in sfg") + + source_op_idx = id_idx_map[source_port.operation.graph_id] + source_op_output_index = source_port.index + new_source_op = new_ops[layer][source_op_idx] + source_op_output = new_source_op.outputs[ + source_op_output_index + ] + + # If this is the last layer, we need to create a new delay element and connect it instead + # of the copied port + if layer == factor - 1: + delay = Delay(name=op.name) + delay.graph_id = op.graph_id + + # Since we're adding a new operation instead of bypassing as in the + # common case, we also need to hook up the inputs to the delay. + delay.inputs[0].connect(source_op_output) + + new_source_op = delay + new_source_port = new_source_op.outputs[0] + else: + # The new output port we should connect to + new_source_port = source_op_output + + for out_signal in op.outputs[0].signals: + sink_port = out_signal.destination + if sink_port is None: + # It would be weird if we found a signal but it wasn't connected anywere + raise ValueError("Dangling output port in sfg") + + sink_op_idx = id_idx_map[sink_port.operation.graph_id] + sink_op_output_index = sink_port.index + + target_layer = 0 if layer == factor - 1 else layer + 1 + + new_dest_op = new_ops[target_layer][sink_op_idx] + new_destination = new_dest_op.inputs[ + sink_op_output_index + ] + new_destination.connect(new_source_port) + else: + # Other opreations need to be re-targeted to the corresponding output in the + # current layer, as long as that output is not a delay, as that has been solved + # above. + # To avoid double connections, we'll only re-connect inputs + for input_num, original_input in enumerate(op.inputs): + original_source = original_input.connected_source + # We may not always have something connected to the input, if we don't + # we can abort + if original_source is None: + continue + + # delay connections are handled elsewhere + if not isinstance(original_source.operation, Delay): + source_op_idx = id_idx_map[ + original_source.operation.graph_id + ] + source_op_output_idx = original_source.index + + target_output = new_ops[layer][ + source_op_idx + ].outputs[source_op_output_idx] + + new_ops[layer][op_idx].inputs[input_num].connect( + target_output + ) + + all_ops = [op for op_list in new_ops for op in op_list] + + # To get the input order correct, we need to know the input order in the original + # sfg and which operations they correspond to + input_ids = [op.graph_id for op in self.input_operations] + output_ids = [op.graph_id for op in self.output_operations] + + # Re-order the inputs to the correct order. Internal order of the inputs should + # be preserved, i.e. for a graph with 2 inputs (in1, in2), in1 must occur before in2, + # but the "time" order should be reversed. I.e. the input from layer `factor-1` is the + # first input + all_inputs = list( + itertools.chain.from_iterable( + [ + [ops[id_idx_map[input_id]] for input_id in input_ids] + for ops in new_ops + ] + ) + ) + + # Outputs are not reversed, but need the same treatment + all_outputs = list( + itertools.chain.from_iterable( + [ + [ops[id_idx_map[output_id]] for output_id in output_ids] + for ops in new_ops + ] + ) + ) + + # Sanity check to ensure that no duplicate graph IDs have been created + ids = [op.graph_id for op in all_ops] + assert len(ids) == len(set(ids)) + + return SFG(inputs=all_inputs, outputs=all_outputs) diff --git a/examples/twotapfirsfg.py b/examples/twotapfirsfg.py index e111e2a3118369302639d0666507bc20436b748d..87a90e1c340add655b42c14660bed860e85a4ed6 100644 --- a/examples/twotapfirsfg.py +++ b/examples/twotapfirsfg.py @@ -14,21 +14,21 @@ from b_asic import ( ) # Inputs: -in1 = Input(name="in1") +in1 = Input(name="in_1") # Outputs: -out1 = Output(name="") +out1 = Output(name="out1") # Operations: -t1 = Delay(initial_value=0, name="") +t1 = Delay(initial_value=0, name="t1") cmul1 = ConstantMultiplication( - value=0.5, name="cmul2", latency_offsets={'in0': None, 'out0': None} + value=0.5, name="cmul1", latency_offsets={'in0': None, 'out0': None} ) add1 = Addition( - name="", latency_offsets={'in0': None, 'in1': None, 'out0': None} + name="add1", latency_offsets={'in0': None, 'in1': None, 'out0': None} ) cmul2 = ConstantMultiplication( - value=0.5, name="cmul", latency_offsets={'in0': None, 'out0': None} + value=0.5, name="cmul2", latency_offsets={'in0': None, 'out0': None} ) # Signals: diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index eb876cec25ae8c6a2e743af6b8c2542d17f8863d..321e7523c5e7ec59e31cd1affd7384dac8b5b7ad 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -13,6 +13,7 @@ from b_asic import ( Input, Name, Output, + Signal, SignalSourceProvider, TypeName, ) @@ -274,3 +275,34 @@ def precedence_sfg_delays_and_constants(): Output(bfly1.output(1), "OUT2") return SFG(inputs=[in1], outputs=[out1], name="SFG") + + +@pytest.fixture +def sfg_two_tap_fir(): + # Inputs: + in1 = Input(name="in1") + + # Outputs: + out1 = Output(name="out1") + + # Operations: + t1 = Delay(initial_value=0, name="t1") + cmul1 = ConstantMultiplication( + value=0.5, name="cmul1", latency_offsets={'in0': None, 'out0': None} + ) + add1 = Addition( + name="add1", latency_offsets={'in0': None, 'in1': None, 'out0': None} + ) + cmul2 = ConstantMultiplication( + value=0.5, name="cmul2", latency_offsets={'in0': None, 'out0': None} + ) + + # Signals: + + Signal(source=t1.output(0), destination=cmul1.input(0)) + Signal(source=in1.output(0), destination=t1.input(0)) + Signal(source=in1.output(0), destination=cmul2.input(0)) + Signal(source=cmul1.output(0), destination=add1.input(0)) + Signal(source=add1.output(0), destination=out1.input(0)) + Signal(source=cmul2.output(0), destination=add1.input(1)) + return SFG(inputs=[in1], outputs=[out1], name='twotapfir') diff --git a/test/test_sfg.py b/test/test_sfg.py index a4276fa66cc3c5a1fecf0a9d86315d21e14bdb4c..601df1099a2b4b6a5b1dd9eddfa172ae92f5bdde 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1,9 +1,11 @@ import io +import itertools import random import re import string import sys from os import path, remove +from typing import Counter, Dict, Type import pytest @@ -23,7 +25,9 @@ from b_asic.core_operations import ( Subtraction, SymmetricTwoportAdaptor, ) +from b_asic.operation import ResultKey from b_asic.save_load_structure import python_to_sfg, sfg_to_python +from b_asic.simulation import Simulation from b_asic.special_operations import Delay @@ -1595,3 +1599,110 @@ class TestCriticalPath: sfg_simple_accumulator.set_latency_of_type(Addition.type_name(), 6) assert sfg_simple_accumulator.critical_path() == 6 + + +class TestUnfold: + def count_kinds(self, sfg: SFG) -> Dict[Type, int]: + return Counter([type(op) for op in sfg.operations]) + + # Checks that the number of each kind of operation in sfg2 is multiple*count + # of the same operation in sfg1. + # Filters out delay delays + def assert_counts_is_correct(self, sfg1: SFG, sfg2: SFG, multiple: int): + count1 = self.count_kinds(sfg1) + count2 = self.count_kinds(sfg2) + + # Delays should not be duplicated. Check that and then clear them + # Using get to avoid issues if there are no delays in the sfg + assert count1.get(Delay) == count2.get(Delay) + count1[Delay] = 0 + count2[Delay] = 0 + + # Ensure that we aren't missing any keys, or have any extras + assert count1.keys() == count2.keys() + + for k in count1.keys(): + assert count1[k] * multiple == count2[k] + + # This is horrifying, but I can't figure out a way to run the test on multiple fixtures, + # so this is an ugly hack until someone that knows pytest comes along + def test_two_inputs_two_outputs(self, sfg_two_inputs_two_outputs: SFG): + self.do_tests(sfg_two_inputs_two_outputs) + + def test_twotapfir(self, sfg_two_tap_fir: SFG): + self.do_tests(sfg_two_tap_fir) + + def test_delay(self, sfg_delay: SFG): + self.do_tests(sfg_delay) + + def test_sfg_two_inputs_two_outputs_independent( + self, sfg_two_inputs_two_outputs_independent: SFG + ): + self.do_tests(sfg_two_inputs_two_outputs_independent) + + def do_tests(self, sfg: SFG): + for factor in range(2, 4): + # Ensure that the correct number of operations get created + unfolded = sfg.unfold(factor) + + self.assert_counts_is_correct(sfg, unfolded, factor) + + double_unfolded = sfg.unfold(factor).unfold(factor) + + self.assert_counts_is_correct( + sfg, double_unfolded, factor * factor + ) + + NUM_TESTS = 5 + # Evaluate with some random values + # To avoid problems with missing inputs at the end of the sequence, + # we generate i*(some large enough) number + input_list = [ + [random.random() for _ in range(0, NUM_TESTS * factor)] + for _ in sfg.inputs + ] + + sim = Simulation(sfg, input_list) + sim.run() + ref = sim.results + + # We have i copies of the inputs, each sourcing their input from the orig + unfolded_input_lists = [ + [] for _ in range(len(sfg.inputs) * factor) + ] + for t in range(0, NUM_TESTS): + for n in range(0, factor): + for k in range(0, len(sfg.inputs)): + unfolded_input_lists[k + n * len(sfg.inputs)].append( + input_list[k][t * factor + n] + ) + + sim = Simulation(unfolded, unfolded_input_lists) + sim.run() + unfolded_results = sim.results + + for n, _ in enumerate(sfg.outputs): + # Outputs for an original output + ref_values = list(ref[ResultKey(f"{n}")]) + + # Output n will be split into `factor` output ports, compute the + # indicies where we find the outputs + out_indices = [n + k * len(sfg.outputs) for k in range(factor)] + u_values = [ + [ + unfolded_results[ResultKey(f"{idx}")][k] + for idx in out_indices + ] + for k in range(int(NUM_TESTS)) + ] + + flat_u_values = list(itertools.chain.from_iterable(u_values)) + + assert flat_u_values == ref_values + + def test_value_error(self, sfg_two_inputs_two_outputs: SFG): + sfg = sfg_two_inputs_two_outputs + with pytest.raises( + ValueError, match="Unfolding 0 times removes the SFG" + ): + sfg.unfold(0)