Skip to content
Snippets Groups Projects
Commit df17cad4 authored by Frans Skarman's avatar Frans Skarman :tropical_fish: Committed by Oscar Gustafsson
Browse files

Add more unfolding tests

parent 2dfa7a04
No related branches found
No related tags found
1 merge request!195Unfolding
......@@ -26,6 +26,7 @@ from typing import (
)
from graphviz import Digraph
from matplotlib.axes import itertools
from b_asic.graph_component import GraphComponent
from b_asic.operation import (
......@@ -1115,11 +1116,6 @@ class SFG(AbstractOperation):
return new_component
def _add_operation_connected_tree_copy(self, start_op: Operation) -> None:
print(
"Running _add_operation_connected_tree_copy with"
f" {self._operations_dfs_order}"
)
print(f"Start op: {start_op}")
op_stack = deque([start_op])
while op_stack:
original_op = op_stack.pop()
......@@ -1169,8 +1165,14 @@ class SFG(AbstractOperation):
original_signal not in self._original_components_to_new
):
if original_signal.source is None:
dest = (
original_signal.destination.operation.name
if original_signal.destination is not None
else "None"
)
raise ValueError(
"Dangling signal without source in SFG"
f" (destination: {dest})"
)
new_signal = cast(
......@@ -1493,6 +1495,19 @@ class SFG(AbstractOperation):
return Schedule(self, scheduling_algorithm="ASAP").schedule_time
def unfold(self, factor: int) -> "SFG":
"""
Unfolds the SFG `factor` times. Returns a new SFG without modifying the original
Inputs and outputs are ordered with early inputs first. I.e. for an sfg
with n inputs, the first n inputs are the inputs at time t, the next n
inputs are the inputs at time t+1, the next n at t+2 and so on.
Parameters
----------
factor : string, optional
Number of times to unfold
"""
if factor == 0:
raise ValueError("Unrollnig 0 times removes the SFG")
......@@ -1508,23 +1523,36 @@ class SFG(AbstractOperation):
# The rest of the process is easier if we clear the connections of the inputs
# and outputs of all operations
for list in new_ops:
for op in list:
for layer, op_list in enumerate(new_ops):
for op_idx, op in enumerate(op_list):
for input in op.inputs:
input.clear()
for output in op.outputs:
output.clear()
# Walk through the operations, replacing delay nodes with connections
for layer in range(factor):
for op_idx, op in enumerate(self.operations):
suffix = layer
new_ops[layer][
op_idx
].name = f"{new_ops[layer][op_idx].name}_{factor-layer}"
# NOTE: These are overwritten later, but it's useful to debug with them
].name = f"{new_ops[layer][op_idx].name}_{suffix}"
# NOTE: Since these IDs are what show up when printing the graph, it
# is helpful to set them. However, this can cause name collisions when
# names in a graph are already suffixed with _n
new_ops[layer][op_idx].graph_id = GraphID(
f"{new_ops[layer][op_idx].graph_id}_{factor-layer}"
f"{new_ops[layer][op_idx].graph_id}_{suffix}"
)
def sanity_check():
all_ops = [op for op_list in new_ops for op in op_list]
cmul201 = [
op for op in all_ops if op.graph_id == GraphID("cmul2_0_1")
]
if cmul201:
print(f"NOW: {cmul201[0]}")
# Walk through the operations, replacing delay nodes with connections
for layer in range(factor):
for op_idx, op in enumerate(self.operations):
if isinstance(op, Delay):
# Port of the operation feeding into this delay
source_port = op.inputs[0].connected_source
......@@ -1553,7 +1581,6 @@ class SFG(AbstractOperation):
else:
# The new output port we should connect to
new_source_port = source_op_output
new_source_port.clear()
for out_signal in op.outputs[0].signals:
sink_port = out_signal.destination
......@@ -1570,7 +1597,6 @@ class SFG(AbstractOperation):
new_destination = new_dest_op.inputs[
sink_op_output_index
]
new_destination.clear()
new_destination.connect(new_source_port)
else:
# Other opreations need to be re-targeted to the corresponding output in the
......@@ -1599,8 +1625,52 @@ class SFG(AbstractOperation):
target_output
)
print(
f"Connecting {new_ops[layer][op_idx].name} <-"
f" {target_output.operation.name} ({target_output.operation.graph_id})"
)
print(f" |>{new_ops[layer][op_idx]}")
print(f" |<{target_output.operation}")
sanity_check()
all_ops = [op for op_list in new_ops for op in op_list]
all_inputs = [op for op in all_ops if isinstance(op, Input)]
all_outputs = [op for op in all_ops if isinstance(op, Output)]
# To get the input order correct, we need to know the input order in the original
# sfg and which operations they correspond to
input_ids = [op.graph_id for op in self.input_operations]
output_ids = [op.graph_id for op in self.output_operations]
# Re-order the inputs to the correct order. Internal order of the inputs should
# be preserved, i.e. for a graph with 2 inputs (in1, in2), in1 must occur before in2,
# but the "time" order should be reversed. I.e. the input from layer `factor-1` is the
# first input
all_inputs = list(
itertools.chain.from_iterable(
[
[ops[id_idx_map[input_id]] for input_id in input_ids]
for ops in new_ops
]
)
)
# Outputs are not reversed, but need the same treatment
all_outputs = list(
itertools.chain.from_iterable(
[
[ops[id_idx_map[output_id]] for output_id in output_ids]
for ops in new_ops
]
)
)
print("All ops: ")
print(*all_ops, sep="\n")
print("All outputs: ")
print(*all_outputs, sep="\n")
# Sanity check to ensure that no duplicate graph IDs have been created
ids = [op.graph_id for op in all_ops]
assert len(ids) == len(set(ids))
return SFG(inputs=all_inputs, outputs=all_outputs)
......@@ -14,7 +14,7 @@ from b_asic import (
)
# Inputs:
in1 = Input(name="in1")
in1 = Input(name="in_1")
# Outputs:
out1 = Output(name="out1")
......
......@@ -13,6 +13,7 @@ from b_asic import (
Input,
Name,
Output,
Signal,
SignalSourceProvider,
TypeName,
)
......@@ -274,3 +275,34 @@ def precedence_sfg_delays_and_constants():
Output(bfly1.output(1), "OUT2")
return SFG(inputs=[in1], outputs=[out1], name="SFG")
@pytest.fixture
def sfg_two_tap_fir():
# Inputs:
in1 = Input(name="in1")
# Outputs:
out1 = Output(name="out1")
# Operations:
t1 = Delay(initial_value=0, name="t1")
cmul1 = ConstantMultiplication(
value=0.5, name="cmul1", latency_offsets={'in0': None, 'out0': None}
)
add1 = Addition(
name="add1", latency_offsets={'in0': None, 'in1': None, 'out0': None}
)
cmul2 = ConstantMultiplication(
value=0.5, name="cmul2", latency_offsets={'in0': None, 'out0': None}
)
# Signals:
Signal(source=t1.output(0), destination=cmul1.input(0))
Signal(source=in1.output(0), destination=t1.input(0))
Signal(source=in1.output(0), destination=cmul2.input(0))
Signal(source=cmul1.output(0), destination=add1.input(0))
Signal(source=add1.output(0), destination=out1.input(0))
Signal(source=cmul2.output(0), destination=add1.input(1))
return SFG(inputs=[in1], outputs=[out1], name='twotapfir')
import io
import itertools
import random
import re
import string
import sys
from os import path, remove
from typing import Counter, Dict, Type
import pytest
......@@ -23,7 +25,9 @@ from b_asic.core_operations import (
Subtraction,
SymmetricTwoportAdaptor,
)
from b_asic.operation import ResultKey
from b_asic.save_load_structure import python_to_sfg, sfg_to_python
from b_asic.simulation import Simulation
from b_asic.special_operations import Delay
......@@ -1598,9 +1602,123 @@ class TestCriticalPath:
class TestUnfold:
# QUESTION: Is it possible to run a test on *all* fixtures?
def test_unfolding_by_factor_0_raises(self, sfg_simple_filter: SFG):
with pytest.raises(ValueError):
sfg_simple_filter.unfold(0)
def count_kinds(self, sfg: SFG) -> Dict[Type, int]:
return Counter([type(op) for op in sfg.operations])
# Checks that the number of each kind of operation in sfg2 is multiple*count
# of the same operation in sfg1.
# Filters out delay delays
def assert_counts_is_correct(self, sfg1: SFG, sfg2: SFG, multiple: int):
count1 = self.count_kinds(sfg1)
count2 = self.count_kinds(sfg2)
# Delays should not be duplicated. Check that and then clear them
# Using get to avoid issues if there are no delays in the sfg
assert count1.get(Delay) == count2.get(Delay)
count1[Delay] = 0
count2[Delay] = 0
# Ensure that we aren't missing any keys, or have any extras
assert count1.keys() == count2.keys()
for k in count1.keys():
assert count1[k] * multiple == count2[k]
# This is horrifying, but I can't figure out a way to run the test on multiple fixtures,
# so this is an ugly hack until someone that knows pytest comes along
def test_two_inputs_two_outputs(self, sfg_two_inputs_two_outputs: SFG):
self.do_tests(sfg_two_inputs_two_outputs)
def test_twotapfir(self, sfg_two_tap_fir: SFG):
self.do_tests(sfg_two_tap_fir)
def test_delay(self, sfg_delay: SFG):
self.do_tests(sfg_delay)
def test_sfg_two_inputs_two_outputs_independent(
self, sfg_two_inputs_two_outputs_independent: SFG
):
self.do_tests(sfg_two_inputs_two_outputs_independent)
def do_tests(self, sfg: SFG):
for factor in range(2, 4):
# Ensure that the correct number of operations get created
unfolded = sfg.unfold(factor)
self.assert_counts_is_correct(sfg, unfolded, factor)
# TODO: Add more tests
double_unfolded = sfg.unfold(factor).unfold(factor)
self.assert_counts_is_correct(
sfg, double_unfolded, factor * factor
)
NUM_TESTS = 5
# Evaluate with some random values
# To avoid problems with missing inputs at the end of the sequence,
# we generate i*(some large enough) number
input_list = [
[random.random() for _ in range(0, NUM_TESTS * factor)]
for _ in sfg.inputs
]
print("In: ")
print(input_list)
sim = Simulation(sfg, input_list)
sim.run()
ref = sim.results
print("out: ")
print(list(ref[ResultKey("0")]))
# We have i copies of the inputs, each sourcing their input from the orig
unfolded_input_lists = [
[] for _ in range(len(sfg.inputs) * factor)
]
for t in range(0, NUM_TESTS):
for n in range(0, factor):
for k in range(0, len(sfg.inputs)):
unfolded_input_lists[k + n * len(sfg.inputs)].append(
input_list[k][t * factor + n]
)
sim = Simulation(unfolded, unfolded_input_lists)
sim.run()
unfolded_results = sim.results
print("ref out: ")
print("0: ", unfolded_results[ResultKey("0")])
print("1: ", unfolded_results[ResultKey("1")])
for n, _ in enumerate(sfg.outputs):
# Outputs for an original output
ref_values = list(ref[ResultKey(f"{n}")])
# Output n will be split into `factor` output ports, compute the
# indicies where we find the outputs
out_indices = [n + k * len(sfg.outputs) for k in range(factor)]
print("out indices: ", out_indices)
u_values = [
[
unfolded_results[ResultKey(f"{idx}")][k]
for idx in out_indices
]
for k in range(int(NUM_TESTS))
]
print("u_values: ", u_values)
flat_u_values = list(itertools.chain.from_iterable(u_values))
print("ref_values: ", ref_values)
print("flat u_values: ", flat_u_values)
assert flat_u_values == ref_values
def test_value_error(self, sfg_two_inputs_two_outputs: SFG):
sfg = sfg_two_inputs_two_outputs
with pytest.raises(
ValueError, match="Unrollnig 0 times removes the SFG"
):
sfg.unfold(0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment