Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • da/B-ASIC
  • lukja239/B-ASIC
  • robal695/B-ASIC
3 results
Show changes
...@@ -59,7 +59,7 @@ documentation = "https://da.gitlab-pages.liu.se/B-ASIC/" ...@@ -59,7 +59,7 @@ documentation = "https://da.gitlab-pages.liu.se/B-ASIC/"
skip-string-normalization = true skip-string-normalization = true
preview = true preview = true
line-length = 88 line-length = 88
exclude = ["test/test_gui", "b_asic/scheduler_gui/ui_main_window.py"] exclude = "(test/test_gui/*|b_asic/scheduler_gui/ui_main_window.py)"
[tool.isort] [tool.isort]
profile = "black" profile = "black"
...@@ -77,3 +77,6 @@ precision = 2 ...@@ -77,3 +77,6 @@ precision = 2
[tool.ruff] [tool.ruff]
ignore = ["F403"] ignore = ["F403"]
[tool.typos]
default.extend-identifiers = { addd0 = "addd0", inout = "inout", ArChItEctUrE = "ArChItEctUrE" }
...@@ -180,13 +180,13 @@ def sfg_simple_filter(): ...@@ -180,13 +180,13 @@ def sfg_simple_filter():
in1---->add1----->t1+---->out1 in1---->add1----->t1+---->out1
. . . .
""" """
in1 = Input("IN1") in1 = Input("IN")
cmul1 = ConstantMultiplication(0.5, name="CMUL1") cmul1 = ConstantMultiplication(0.5, name="CMUL")
add1 = Addition(in1, cmul1, "ADD1") add1 = Addition(in1, cmul1, "ADD")
add1.input(1).signals[0].name = "S2" add1.input(1).signals[0].name = "S2"
t1 = Delay(add1, name="T1") t1 = Delay(add1, name="T")
cmul1.input(0).connect(t1, "S1") cmul1.input(0).connect(t1, "S1")
out1 = Output(t1, "OUT1") out1 = Output(t1, "OUT")
return SFG(inputs=[in1], outputs=[out1], name="simple_filter") return SFG(inputs=[in1], outputs=[out1], name="simple_filter")
......
...@@ -20,7 +20,6 @@ def test_is_valid_vhdl_identifier(): ...@@ -20,7 +20,6 @@ def test_is_valid_vhdl_identifier():
"architecture", "architecture",
"Architecture", "Architecture",
"ArChItEctUrE", "ArChItEctUrE",
"architectURE",
"entity", "entity",
"invalid+", "invalid+",
"invalid}", "invalid}",
......
"""B-ASIC test suite for the core operations.""" """B-ASIC test suite for the core operations."""
import pytest import pytest
from b_asic import ( from b_asic import (
SFG,
Absolute, Absolute,
Addition, Addition,
AddSub, AddSub,
...@@ -17,11 +19,10 @@ from b_asic import ( ...@@ -17,11 +19,10 @@ from b_asic import (
Reciprocal, Reciprocal,
RightShift, RightShift,
Shift, Shift,
Sink,
SquareRoot, SquareRoot,
Subtraction, Subtraction,
SymmetricTwoportAdaptor, SymmetricTwoportAdaptor,
Sink,
SFG,
) )
...@@ -407,6 +408,7 @@ class TestDepends: ...@@ -407,6 +408,7 @@ class TestDepends:
assert set(bfly1.inputs_required_for_output(0)) == {0, 1} assert set(bfly1.inputs_required_for_output(0)) == {0, 1}
assert set(bfly1.inputs_required_for_output(1)) == {0, 1} assert set(bfly1.inputs_required_for_output(1)) == {0, 1}
class TestSink: class TestSink:
def test_create_sfg_with_sink(self): def test_create_sfg_with_sink(self):
bfly = Butterfly() bfly = Butterfly()
...@@ -418,4 +420,4 @@ class TestSink: ...@@ -418,4 +420,4 @@ class TestSink:
assert sfg2.output_count == 1 assert sfg2.output_count == 1
assert sfg2.input_count == 2 assert sfg2.input_count == 2
assert sfg.evaluate_output(1, [0,1]) == sfg2.evaluate_output(0, [0,1]) assert sfg.evaluate_output(1, [0, 1]) == sfg2.evaluate_output(0, [0, 1])
...@@ -212,7 +212,7 @@ class TestProcessCollectionPlainMemoryVariable: ...@@ -212,7 +212,7 @@ class TestProcessCollectionPlainMemoryVariable:
fig, ax = plt.subplots() fig, ax = plt.subplots()
collection = ProcessCollection( collection = ProcessCollection(
{ {
# Process starting exactly at scheudle start # Process starting exactly at schedule start
PlainMemoryVariable(0, 0, {0: 0}, "S1"), PlainMemoryVariable(0, 0, {0: 0}, "S1"),
PlainMemoryVariable(0, 0, {0: 5}, "S2"), PlainMemoryVariable(0, 0, {0: 5}, "S2"),
# Process starting somewhere between schedule start and end # Process starting somewhere between schedule start and end
......
""" """
B-ASIC test suite for the schedule module and Schedule class. B-ASIC test suite for the schedule module and Schedule class.
""" """
import re import re
import pytest
import matplotlib.testing.decorators import matplotlib.testing.decorators
import pytest
from b_asic.core_operations import Addition, Butterfly, ConstantMultiplication from b_asic.core_operations import Addition, Butterfly, ConstantMultiplication
from b_asic.process import OperatorProcess from b_asic.process import OperatorProcess
from b_asic.schedule import Schedule from b_asic.schedule import Schedule
from b_asic.sfg_generators import direct_form_fir
from b_asic.signal_flow_graph import SFG from b_asic.signal_flow_graph import SFG
from b_asic.special_operations import Delay, Input, Output from b_asic.special_operations import Delay, Input, Output
from b_asic.sfg_generators import direct_form_fir
class TestInit: class TestInit:
...@@ -293,7 +294,9 @@ class TestSlacks: ...@@ -293,7 +294,9 @@ class TestSlacks:
schedule = Schedule(precedence_sfg_delays, algorithm="ASAP") schedule = Schedule(precedence_sfg_delays, algorithm="ASAP")
schedule.print_slacks() schedule.print_slacks()
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == """Graph ID | Backward | Forward assert (
captured.out
== """Graph ID | Backward | Forward
---------|----------|--------- ---------|----------|---------
add0 | 0 | 0 add0 | 0 | 0
add1 | 0 | 0 add1 | 0 | 0
...@@ -309,6 +312,7 @@ cmul6 | 4 | 0 ...@@ -309,6 +312,7 @@ cmul6 | 4 | 0
in0 | oo | 0 in0 | oo | 0
out0 | 0 | oo out0 | 0 | oo
""" """
)
assert captured.err == "" assert captured.err == ""
def test_print_slacks_sorting(self, capsys, precedence_sfg_delays): def test_print_slacks_sorting(self, capsys, precedence_sfg_delays):
...@@ -318,7 +322,9 @@ out0 | 0 | oo ...@@ -318,7 +322,9 @@ out0 | 0 | oo
schedule = Schedule(precedence_sfg_delays, algorithm="ASAP") schedule = Schedule(precedence_sfg_delays, algorithm="ASAP")
schedule.print_slacks(1) schedule.print_slacks(1)
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == """Graph ID | Backward | Forward assert (
captured.out
== """Graph ID | Backward | Forward
---------|----------|--------- ---------|----------|---------
cmul0 | 0 | 1 cmul0 | 0 | 1
add0 | 0 | 0 add0 | 0 | 0
...@@ -334,6 +340,7 @@ cmul4 | 16 | 0 ...@@ -334,6 +340,7 @@ cmul4 | 16 | 0
cmul5 | 16 | 0 cmul5 | 16 | 0
in0 | oo | 0 in0 | oo | 0
""" """
)
assert captured.err == "" assert captured.err == ""
def test_slacks_errors(self, precedence_sfg_delays): def test_slacks_errors(self, precedence_sfg_delays):
...@@ -521,11 +528,15 @@ class TestRescheduling: ...@@ -521,11 +528,15 @@ class TestRescheduling:
assert schedule._start_times["add0"] == 0 assert schedule._start_times["add0"] == 0
assert schedule._start_times["out0"] == 2 assert schedule._start_times["out0"] == 2
def test_reintroduce_delays(self, precedence_sfg_delays, sfg_direct_form_iir_lp_filter): def test_reintroduce_delays(
self, precedence_sfg_delays, sfg_direct_form_iir_lp_filter
):
precedence_sfg_delays.set_latency_of_type(Addition.type_name(), 1) precedence_sfg_delays.set_latency_of_type(Addition.type_name(), 1)
precedence_sfg_delays.set_latency_of_type(ConstantMultiplication.type_name(), 3) precedence_sfg_delays.set_latency_of_type(ConstantMultiplication.type_name(), 3)
sfg_direct_form_iir_lp_filter.set_latency_of_type(Addition.type_name(), 1) sfg_direct_form_iir_lp_filter.set_latency_of_type(Addition.type_name(), 1)
sfg_direct_form_iir_lp_filter.set_latency_of_type(ConstantMultiplication.type_name(), 3) sfg_direct_form_iir_lp_filter.set_latency_of_type(
ConstantMultiplication.type_name(), 3
)
schedule = Schedule(precedence_sfg_delays, algorithm="ASAP") schedule = Schedule(precedence_sfg_delays, algorithm="ASAP")
sfg = schedule.sfg sfg = schedule.sfg
...@@ -536,20 +547,15 @@ class TestRescheduling: ...@@ -536,20 +547,15 @@ class TestRescheduling:
assert sfg_direct_form_iir_lp_filter.evaluate(5) == sfg.evaluate(5) assert sfg_direct_form_iir_lp_filter.evaluate(5) == sfg.evaluate(5)
fir_sfg = direct_form_fir( fir_sfg = direct_form_fir(
list(range(1, 10)), list(range(1, 10)),
mult_properties={ mult_properties={'latency': 2, 'execution_time': 1},
'latency': 2, add_properties={'latency': 2, 'execution_time': 1},
'execution_time': 1 )
},
add_properties={
'latency': 2,
'execution_time': 1
}
)
schedule = Schedule(fir_sfg, algorithm="ASAP") schedule = Schedule(fir_sfg, algorithm="ASAP")
sfg = schedule.sfg sfg = schedule.sfg
assert fir_sfg.evaluate(5) == sfg.evaluate(5) assert fir_sfg.evaluate(5) == sfg.evaluate(5)
class TestTimeResolution: class TestTimeResolution:
def test_increase_time_resolution( def test_increase_time_resolution(
self, sfg_two_inputs_two_outputs_independent_with_cmul self, sfg_two_inputs_two_outputs_independent_with_cmul
...@@ -694,7 +700,9 @@ class TestProcesses: ...@@ -694,7 +700,9 @@ class TestProcesses:
class TestFigureGeneration: class TestFigureGeneration:
@matplotlib.testing.decorators.image_comparison(['test__get_figure_no_execution_times.png'], remove_text=True) @matplotlib.testing.decorators.image_comparison(
['test__get_figure_no_execution_times.png'], remove_text=True
)
def test__get_figure_no_execution_times(self, secondorder_iir_schedule): def test__get_figure_no_execution_times(self, secondorder_iir_schedule):
return secondorder_iir_schedule._get_figure() return secondorder_iir_schedule._get_figure()
......
...@@ -22,10 +22,10 @@ from b_asic.core_operations import ( ...@@ -22,10 +22,10 @@ from b_asic.core_operations import (
) )
from b_asic.operation import ResultKey from b_asic.operation import ResultKey
from b_asic.save_load_structure import python_to_sfg, sfg_to_python from b_asic.save_load_structure import python_to_sfg, sfg_to_python
from b_asic.sfg_generators import wdf_allpass
from b_asic.signal_flow_graph import SFG, GraphID from b_asic.signal_flow_graph import SFG, GraphID
from b_asic.simulation import Simulation from b_asic.simulation import Simulation
from b_asic.special_operations import Delay from b_asic.special_operations import Delay
from b_asic.sfg_generators import wdf_allpass
class TestInit: class TestInit:
...@@ -163,15 +163,15 @@ class TestPrintSfg: ...@@ -163,15 +163,15 @@ class TestPrintSfg:
+ "Internal Operations:\n" + "Internal Operations:\n"
+ "--------------------------------------------------------------------" + "--------------------------------------------------------------------"
+ "--------------------------------\n" + "--------------------------------\n"
+ str(sfg_simple_filter.find_by_name("IN1")[0]) + str(sfg_simple_filter.find_by_name("IN")[0])
+ "\n" + "\n"
+ str(sfg_simple_filter.find_by_name("ADD1")[0]) + str(sfg_simple_filter.find_by_name("ADD")[0])
+ "\n" + "\n"
+ str(sfg_simple_filter.find_by_name("T1")[0]) + str(sfg_simple_filter.find_by_name("T")[0])
+ "\n" + "\n"
+ str(sfg_simple_filter.find_by_name("CMUL1")[0]) + str(sfg_simple_filter.find_by_name("CMUL")[0])
+ "\n" + "\n"
+ str(sfg_simple_filter.find_by_name("OUT1")[0]) + str(sfg_simple_filter.find_by_name("OUT")[0])
+ "\n" + "\n"
+ "--------------------------------------------------------------------" + "--------------------------------------------------------------------"
+ "--------------------------------\n" + "--------------------------------\n"
...@@ -816,7 +816,7 @@ class TestConnectExternalSignalsToComponentsSoloComp: ...@@ -816,7 +816,7 @@ class TestConnectExternalSignalsToComponentsSoloComp:
assert not test_sfg.connect_external_signals_to_components() assert not test_sfg.connect_external_signals_to_components()
def test_connect_external_signals_to_components_multiple_operations_after_input( def test_connect_external_signals_to_components_multiple_operations_after_input(
self self,
): ):
""" """
Replaces an SFG with a symmetric two-port adaptor to test when the input Replaces an SFG with a symmetric two-port adaptor to test when the input
...@@ -830,6 +830,7 @@ class TestConnectExternalSignalsToComponentsSoloComp: ...@@ -830,6 +830,7 @@ class TestConnectExternalSignalsToComponentsSoloComp:
assert test_sfg.evaluate(1) == -0.5 assert test_sfg.evaluate(1) == -0.5
assert not test_sfg.connect_external_signals_to_components() assert not test_sfg.connect_external_signals_to_components()
class TestConnectExternalSignalsToComponentsMultipleComp: class TestConnectExternalSignalsToComponentsMultipleComp:
def test_connect_external_signals_to_components_operation_tree( def test_connect_external_signals_to_components_operation_tree(
self, operation_tree self, operation_tree
...@@ -957,11 +958,11 @@ class TestTopologicalOrderOperations: ...@@ -957,11 +958,11 @@ class TestTopologicalOrderOperations:
topological_order = sfg_simple_filter.get_operations_topological_order() topological_order = sfg_simple_filter.get_operations_topological_order()
assert [comp.name for comp in topological_order] == [ assert [comp.name for comp in topological_order] == [
"IN1", "IN",
"ADD1", "ADD",
"T1", "T",
"CMUL1", "CMUL",
"OUT1", "OUT",
] ]
def test_multiple_independent_inputs(self, sfg_two_inputs_two_outputs_independent): def test_multiple_independent_inputs(self, sfg_two_inputs_two_outputs_independent):
...@@ -1006,26 +1007,25 @@ class TestRemove: ...@@ -1006,26 +1007,25 @@ class TestRemove:
assert { assert {
op.name op.name
for op in sfg_simple_filter.find_by_name("T1")[0].subsequent_operations for op in sfg_simple_filter.find_by_name("T")[0].subsequent_operations
} == {"CMUL1", "OUT1"} } == {"CMUL", "OUT"}
assert { assert {
op.name for op in new_sfg.find_by_name("T1")[0].subsequent_operations op.name for op in new_sfg.find_by_name("T")[0].subsequent_operations
} == {"ADD1", "OUT1"} } == {"ADD", "OUT"}
assert { assert {
op.name op.name
for op in sfg_simple_filter.find_by_name("ADD1")[0].preceding_operations for op in sfg_simple_filter.find_by_name("ADD")[0].preceding_operations
} == {"CMUL1", "IN1"} } == {"CMUL", "IN"}
assert { assert {
op.name for op in new_sfg.find_by_name("ADD1")[0].preceding_operations op.name for op in new_sfg.find_by_name("ADD")[0].preceding_operations
} == {"T1", "IN1"} } == {"T", "IN"}
assert "S1" in { assert "S1" in {
sig.name sig.name for sig in sfg_simple_filter.find_by_name("T")[0].output(0).signals
for sig in sfg_simple_filter.find_by_name("T1")[0].output(0).signals
} }
assert "S2" in { assert "S2" in {
sig.name for sig in new_sfg.find_by_name("T1")[0].output(0).signals sig.name for sig in new_sfg.find_by_name("T")[0].output(0).signals
} }
def test_remove_multiple_inputs_outputs(self, butterfly_operation_tree): def test_remove_multiple_inputs_outputs(self, butterfly_operation_tree):
...@@ -1176,7 +1176,7 @@ class TestGetComponentsOfType: ...@@ -1176,7 +1176,7 @@ class TestGetComponentsOfType:
) )
] == [] ] == []
def test_get_multple_operations_of_type(self, sfg_two_inputs_two_outputs): def test_get_multiple_operations_of_type(self, sfg_two_inputs_two_outputs):
assert [ assert [
op.name op.name
for op in sfg_two_inputs_two_outputs.find_by_type_name(Addition.type_name()) for op in sfg_two_inputs_two_outputs.find_by_type_name(Addition.type_name())
...@@ -1216,59 +1216,105 @@ class TestPrecedenceGraph: ...@@ -1216,59 +1216,105 @@ class TestPrecedenceGraph:
class TestSFGGraph: class TestSFGGraph:
def test_sfg(self, sfg_simple_filter): def test_sfg(self, sfg_simple_filter):
res = ( res = """digraph {
'digraph {\n\trankdir=LR splines=spline\n\tin0 [shape=cds]\n\tin0 -> add0' rankdir=LR splines=spline
' [headlabel=0]\n\tout0 [shape=cds]\n\tt0 -> out0\n\tadd0' in0 [label="IN
' [shape=ellipse]\n\tcmul0 -> add0 [headlabel=1]\n\tcmul0' (in0)" shape=cds]
' [shape=ellipse]\n\tadd0 -> t0\n\tt0 [shape=square]\n\tt0 -> cmul0\n}' in0 -> add0 [headlabel=0]
) out0 [label="OUT
assert sfg_simple_filter.sfg_digraph(branch_node=False).source in ( (out0)" shape=cds]
"t0.0" -> out0
"t0.0" [shape=point]
t0 -> "t0.0" [arrowhead=none]
add0 [label="ADD
(add0)" shape=ellipse]
cmul0 -> add0 [headlabel=1]
cmul0 [label="CMUL
(cmul0)" shape=ellipse]
add0 -> t0
t0 [label="T
(t0)" shape=square]
"t0.0" -> cmul0
}"""
assert sfg_simple_filter.sfg_digraph().source in (
res, res,
res + "\n", res + "\n",
) )
def test_sfg_show_id(self, sfg_simple_filter): def test_sfg_show_signal_id(self, sfg_simple_filter):
res = ( res = """digraph {
'digraph {\n\trankdir=LR splines=spline\n\tin0 [shape=cds]\n\tin0 -> add0' rankdir=LR splines=spline
' [label=s0 headlabel=0]\n\tout0 [shape=cds]\n\tt0 -> out0' in0 [label="IN
' [label=s1]\n\tadd0 [shape=ellipse]\n\tcmul0 -> add0 [label=s2' (in0)" shape=cds]
' headlabel=1]\n\tcmul0 [shape=ellipse]\n\tadd0 -> t0 [label=s3]\n\tt0' in0 -> add0 [label=s0 headlabel=0]
' [shape=square]\n\tt0 -> cmul0 [label=s4]\n}' out0 [label="OUT
) (out0)" shape=cds]
"t0.0" -> out0 [label=s1]
assert sfg_simple_filter.sfg_digraph( "t0.0" [shape=point]
show_id=True, branch_node=False t0 -> "t0.0" [arrowhead=none]
).source in ( add0 [label="ADD
(add0)" shape=ellipse]
cmul0 -> add0 [label=s2 headlabel=1]
cmul0 [label="CMUL
(cmul0)" shape=ellipse]
add0 -> t0 [label=s3]
t0 [label="T
(t0)" shape=square]
"t0.0" -> cmul0 [label=s4]
}"""
assert sfg_simple_filter.sfg_digraph(show_signal_id=True).source in (
res, res,
res + "\n", res + "\n",
) )
def test_sfg_branch(self, sfg_simple_filter): def test_sfg_no_branch(self, sfg_simple_filter):
res = ( res = """digraph {
'digraph {\n\trankdir=LR splines=spline\n\tin0 [shape=cds]\n\tin0 -> add0' rankdir=LR splines=spline
' [headlabel=0]\n\tout0 [shape=cds]\n\t"t0.0" -> out0\n\t"t0.0"' in0 [label="IN
' [shape=point]\n\tt0 -> "t0.0" [arrowhead=none]\n\tadd0' (in0)" shape=cds]
' [shape=ellipse]\n\tcmul0 -> add0 [headlabel=1]\n\tcmul0' in0 -> add0 [headlabel=0]
' [shape=ellipse]\n\tadd0 -> t0\n\tt0 [shape=square]\n\t"t0.0" ->' out0 [label="OUT
' cmul0\n}' (out0)" shape=cds]
) t0 -> out0
add0 [label="ADD
assert sfg_simple_filter.sfg_digraph().source in ( (add0)" shape=ellipse]
cmul0 -> add0 [headlabel=1]
cmul0 [label="CMUL
(cmul0)" shape=ellipse]
add0 -> t0
t0 [label="T
(t0)" shape=square]
t0 -> cmul0
}"""
assert sfg_simple_filter.sfg_digraph(branch_node=False).source in (
res, res,
res + "\n", res + "\n",
) )
def test_sfg_no_port_numbering(self, sfg_simple_filter): def test_sfg_no_port_numbering(self, sfg_simple_filter):
res = ( res = """digraph {
'digraph {\n\trankdir=LR splines=spline\n\tin0 [shape=cds]\n\tin0 ->' rankdir=LR splines=spline
' add0\n\tout0 [shape=cds]\n\tt0 -> out0\n\tadd0 [shape=ellipse]\n\tcmul0' in0 [label="IN
' -> add0\n\tcmul0 [shape=ellipse]\n\tadd0 -> t0\n\tt0 [shape=square]\n\tt0' (in0)" shape=cds]
' -> cmul0\n}' in0 -> add0
) out0 [label="OUT
(out0)" shape=cds]
assert sfg_simple_filter.sfg_digraph( "t0.0" -> out0
port_numbering=False, branch_node=False "t0.0" [shape=point]
).source in ( t0 -> "t0.0" [arrowhead=none]
add0 [label="ADD
(add0)" shape=ellipse]
cmul0 -> add0
cmul0 [label="CMUL
(cmul0)" shape=ellipse]
add0 -> t0
t0 [label="T
(t0)" shape=square]
"t0.0" -> cmul0
}"""
assert sfg_simple_filter.sfg_digraph(port_numbering=False).source in (
res, res,
res + "\n", res + "\n",
) )
...@@ -1480,9 +1526,7 @@ class TestUnfold: ...@@ -1480,9 +1526,7 @@ class TestUnfold:
): ):
self.do_tests(sfg_two_inputs_two_outputs_independent) self.do_tests(sfg_two_inputs_two_outputs_independent)
def test_threetapiir( def test_threetapiir(self, sfg_direct_form_iir_lp_filter: SFG):
self, sfg_direct_form_iir_lp_filter: SFG
):
self.do_tests(sfg_direct_form_iir_lp_filter) self.do_tests(sfg_direct_form_iir_lp_filter)
def do_tests(self, sfg: SFG): def do_tests(self, sfg: SFG):
...@@ -1527,7 +1571,7 @@ class TestUnfold: ...@@ -1527,7 +1571,7 @@ class TestUnfold:
ref_values = list(ref[ResultKey(f"{n}")]) ref_values = list(ref[ResultKey(f"{n}")])
# Output n will be split into `factor` output ports, compute the # Output n will be split into `factor` output ports, compute the
# indicies where we find the outputs # indices where we find the outputs
out_indices = [n + k * len(sfg.outputs) for k in range(factor)] out_indices = [n + k * len(sfg.outputs) for k in range(factor)]
u_values = [ u_values = [
[unfolded_results[ResultKey(f"{idx}")][k] for idx in out_indices] [unfolded_results[ResultKey(f"{idx}")][k] for idx in out_indices]
...@@ -1635,6 +1679,7 @@ class TestInsertComponentAfter: ...@@ -1635,6 +1679,7 @@ class TestInsertComponentAfter:
with pytest.raises(ValueError, match="Unknown component:"): with pytest.raises(ValueError, match="Unknown component:"):
sfg.insert_operation_after('foo', SquareRoot()) sfg.insert_operation_after('foo', SquareRoot())
class TestInsertComponentBefore: class TestInsertComponentBefore:
def test_insert_component_before_in_sfg(self, butterfly_operation_tree): def test_insert_component_before_in_sfg(self, butterfly_operation_tree):
sfg = SFG(outputs=list(map(Output, butterfly_operation_tree.outputs))) sfg = SFG(outputs=list(map(Output, butterfly_operation_tree.outputs)))
...@@ -1653,22 +1698,22 @@ class TestInsertComponentBefore: ...@@ -1653,22 +1698,22 @@ class TestInsertComponentBefore:
SquareRoot, SquareRoot,
) )
assert isinstance( assert isinstance(
_sfg.find_by_name("bfly1")[0] _sfg.find_by_name("bfly1")[0].input(0).signals[0].source.operation,
.input(0)
.signals[0]
.source.operation,
SquareRoot, SquareRoot,
) )
assert sfg.find_by_name("bfly1")[0].input(0).signals[ assert (
0 sfg.find_by_name("bfly1")[0].input(0).signals[0].source.operation
].source.operation is sfg.find_by_name("bfly2")[0] is sfg.find_by_name("bfly2")[0]
assert _sfg.find_by_name("bfly1")[0].input(0).signals[ )
0 assert (
].destination.operation is not _sfg.find_by_name("bfly2")[0] _sfg.find_by_name("bfly1")[0].input(0).signals[0].destination.operation
assert _sfg.find_by_id("sqrt0").input(0).signals[ is not _sfg.find_by_name("bfly2")[0]
0 )
].source.operation is _sfg.find_by_name("bfly2")[0] assert (
_sfg.find_by_id("sqrt0").input(0).signals[0].source.operation
is _sfg.find_by_name("bfly2")[0]
)
def test_insert_component_before_mimo_operation_error( def test_insert_component_before_mimo_operation_error(
self, large_operation_tree_names self, large_operation_tree_names
......