Newer
Older
Angus Lothian
committed
import random
Angus Lothian
committed
import string
import sys
Angus Lothian
committed
Angus Lothian
committed
from b_asic import Input, Output, Signal
from b_asic.core_operations import (
Addition,
Butterfly,
Constant,
ConstantMultiplication,
Multiplication,
SquareRoot,
Subtraction,
SymmetricTwoportAdaptor,
from b_asic.save_load_structure import python_to_sfg, sfg_to_python
from b_asic.sfg_generators import wdf_allpass
from b_asic.signal_flow_graph import SFG, GraphID
from b_asic.special_operations import Delay
Angus Lothian
committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
class TestInit:
    """Construction of an SFG from operations or from signals."""

    def test_direct_input_to_output_sfg_construction(self):
        """An Input connected straight to an Output forms a minimal SFG."""
        in1 = Input("IN1")
        out1 = Output(None, "OUT1")
        out1.input(0).connect(in1, "S1")

        sfg = SFG(inputs=[in1], outputs=[out1])  # in1 ---s1---> out1

        # Components count both operations and signals: IN1, OUT1 and S1.
        assert len(list(sfg.components)) == 3
        assert len(list(sfg.operations)) == 2
        assert sfg.input_count == 1
        assert sfg.output_count == 1

    def test_same_signal_input_and_output_sfg_construction(self):
        """A single signal may serve as both the input and output signal."""
        add1 = Addition(None, None, "ADD1")
        add2 = Addition(None, None, "ADD2")
        s1 = add2.input(0).connect(add1, "S1")

        # add1 ---s1---> add2; only s1 is handed to the SFG, on both sides.
        sfg = SFG(input_signals=[s1], output_signals=[s1])

        assert len(list(sfg.components)) == 3
        assert len(list(sfg.operations)) == 2
        assert sfg.input_count == 1
        assert sfg.output_count == 1

    def test_outputs_construction(self, operation_tree):
        """An SFG can be built from outputs alone, with no explicit inputs."""
        sfg = SFG(outputs=[Output(operation_tree)])

        assert len(list(sfg.components)) == 7
        assert len(list(sfg.operations)) == 4
        assert sfg.input_count == 0
        assert sfg.output_count == 1

    def test_signals_construction(self, operation_tree):
        """An SFG can be built from an output signal instead of an Output."""
        sfg = SFG(output_signals=[Signal(source=operation_tree.output(0))])

        assert len(list(sfg.components)) == 7
        assert len(list(sfg.operations)) == 4
        assert sfg.input_count == 0
        assert sfg.output_count == 1
class TestPrintSfg:
# Tests of the textual rendering of an SFG. Judging by the expected strings,
# the rendering is: a header line ("id: ..., name: ..., inputs: ..., outputs: ..."),
# a dashed separator, one line per operation (str() of each operation found by
# name), and a closing dashed separator.
# NOTE(review): in every test below, the line that introduces the comparison
# (presumably `assert str(sfg) == (`) and the matching closing `)` appear to have
# been lost in extraction. What remains is a bare `== "..."`, which is not valid
# Python. Restore these lines before running.
def test_one_addition(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
add1 = Addition(inp1, inp2, "ADD1")
out1 = Output(add1, "OUT1")
sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="SFG1")
# NOTE(review): left-hand side of this comparison is missing (see class note).
== "id: no_id, \tname: SFG1, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n"
+ "--------------------------------------------------------------------"
+ "--------------------------------\n"
+ str(sfg.find_by_name("INP1")[0])
+ "\n"
+ str(sfg.find_by_name("INP2")[0])
+ "\n"
+ str(sfg.find_by_name("ADD1")[0])
+ "\n"
+ str(sfg.find_by_name("OUT1")[0])
+ "\n"
+ "--------------------------------------------------------------------"
+ "--------------------------------\n"
Angus Lothian
committed
def test_add_mul(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(inp1, inp2, "ADD1")
mul1 = Multiplication(add1, inp3, "MUL1")
out1 = Output(mul1, "OUT1")
sfg = SFG(inputs=[inp1, inp2, inp3], outputs=[out1], name="mac_sfg")
# NOTE(review): left-hand side of this comparison is missing (see class note).
# The expected operation order here is INP1, INP2, ADD1, INP3, MUL1, OUT1.
== "id: no_id, \tname: mac_sfg, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n"
+ "--------------------------------------------------------------------"
+ "--------------------------------\n"
+ str(sfg.find_by_name("INP1")[0])
+ "\n"
+ str(sfg.find_by_name("INP2")[0])
+ "\n"
+ str(sfg.find_by_name("ADD1")[0])
+ "\n"
+ str(sfg.find_by_name("INP3")[0])
+ "\n"
+ str(sfg.find_by_name("MUL1")[0])
+ "\n"
+ str(sfg.find_by_name("OUT1")[0])
+ "\n"
+ "--------------------------------------------------------------------"
+ "--------------------------------\n"
Angus Lothian
committed
def test_constant(self):
inp1 = Input("INP1")
const1 = Constant(3, "CONST")
add1 = Addition(const1, inp1, "ADD1")
out1 = Output(add1, "OUT1")
sfg = SFG(inputs=[inp1], outputs=[out1], name="sfg")
# NOTE(review): left-hand side of this comparison is missing (see class note).
# The Constant is expected to be listed before the Input here.
== "id: no_id, \tname: sfg, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n"
+ "--------------------------------------------------------------------"
+ "--------------------------------\n"
+ str(sfg.find_by_name("CONST")[0])
+ "\n"
+ str(sfg.find_by_name("INP1")[0])
+ "\n"
+ str(sfg.find_by_name("ADD1")[0])
+ "\n"
+ str(sfg.find_by_name("OUT1")[0])
+ "\n"
+ "--------------------------------------------------------------------"
+ "--------------------------------\n"
Angus Lothian
committed
def test_simple_filter(self, sfg_simple_filter):
# NOTE(review): left-hand side of this comparison is missing (see class note).
# Unlike the other tests, no per-operation lines appear between the two
# separators below. Lines may have been lost here as well; confirm against
# the sfg_simple_filter fixture before restoring.
== "id: no_id, \tname: simple_filter, \tinputs: {0: '-'},"
" \toutputs: {0: '-'}\n"
+ "--------------------------------------------------------------------"
+ "--------------------------------\n"
+ "--------------------------------------------------------------------"
+ "--------------------------------\n"
Angus Lothian
committed
class TestDeepCopy:
# Tests that calling an SFG instance produces a copy: the copy takes the name
# passed to the call (default "" as asserted below), keeps id_number_offset,
# and exposes the same graph ids mapped to same-named components.
# NOTE(review): "Angus Lothian" / "committed" lines and the bare numbers in this
# class are blame-page residue from extraction, not code.
def test_deep_copy_no_duplicates(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(inp1, inp2, "ADD1")
mul1 = Multiplication(add1, inp3, "MUL1")
out1 = Output(mul1, "OUT1")
mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg")
# Calling with no arguments: the copy gets an empty name.
mac_sfg_new = mac_sfg()
assert mac_sfg.name == "mac_sfg"
assert mac_sfg_new.name == ""
# Every graph id in the original must resolve in the copy to a component
# with the same name.
for g_id, component in mac_sfg._components_by_id.items():
component_copy = mac_sfg_new.find_by_id(g_id)
assert component.name == component_copy.name
Angus Lothian
committed
def test_deep_copy(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(None, None, "ADD1")
add2 = Addition(None, None, "ADD2")
mul1 = Multiplication(None, None, "MUL1")
out1 = Output(None, "OUT1")
add1.input(0).connect(inp1, "S1")
add1.input(1).connect(inp2, "S2")
add2.input(0).connect(add1, "S4")
add2.input(1).connect(inp3, "S3")
mul1.input(0).connect(add1, "S5")
mul1.input(1).connect(add2, "S6")
out1.input(0).connect(mul1, "S7")
mac_sfg = SFG(
inputs=[inp1, inp2],
outputs=[out1],
id_number_offset=100,
name="mac_sfg",
)
# NOTE(review): original file lines 224-262 appear to have been lost in
# extraction; only their line numbers survived below. Recover them from
# version control before running this test.
Angus Lothian
committed
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
# Calling with name= renames the copy; id_number_offset is carried over.
mac_sfg_new = mac_sfg(name="mac_sfg2")
assert mac_sfg.name == "mac_sfg"
assert mac_sfg_new.name == "mac_sfg2"
assert mac_sfg.id_number_offset == 100
assert mac_sfg_new.id_number_offset == 100
for g_id, component in mac_sfg._components_by_id.items():
component_copy = mac_sfg_new.find_by_id(g_id)
assert component.name == component_copy.name
def test_deep_copy_with_new_sources(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(inp1, inp2, "ADD1")
mul1 = Multiplication(add1, inp3, "MUL1")
out1 = Output(mul1, "OUT1")
mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg")
a = Addition(Constant(3), Constant(5))
b = Constant(2)
# Positional arguments become the sources feeding the copy's inputs, in
# order.
mac_sfg_new = mac_sfg(a, b)
assert mac_sfg_new.input(0).signals[0].source.operation is a
assert mac_sfg_new.input(1).signals[0].source.operation is b
class TestEvaluateOutput:
    """Numeric evaluation of SFG outputs built from fixture graphs."""

    def test_evaluate_output(self, operation_tree):
        """Output 0 of the small fixture tree evaluates to 5 with no inputs."""
        sfg = SFG(outputs=[Output(operation_tree)])
        assert sfg.evaluate_output(0, []) == 5

    def test_evaluate_output_large(self, large_operation_tree):
        """Output 0 of the large fixture tree evaluates to 14 with no inputs."""
        sfg = SFG(outputs=[Output(large_operation_tree)])
        assert sfg.evaluate_output(0, []) == 14

    def test_evaluate_output_cycle(self, operation_graph_with_cycle):
        """Evaluating a graph with a direct feedback loop raises RuntimeError."""
        sfg = SFG(outputs=[Output(operation_graph_with_cycle)])
        with pytest.raises(RuntimeError, match="Direct feedback loop detected"):
            sfg.evaluate_output(0, [])
class TestComponents:
    """Enumeration of SFG components (operations and named signals)."""

    def test_advanced_components(self):
        """All operations and signals reachable from the outputs are listed."""
        inp1 = Input("INP1")
        inp2 = Input("INP2")
        inp3 = Input("INP3")
        add1 = Addition(None, None, "ADD1")
        add2 = Addition(None, None, "ADD2")
        mul1 = Multiplication(None, None, "MUL1")
        out1 = Output(None, "OUT1")

        add1.input(0).connect(inp1, "S1")
        add1.input(1).connect(inp2, "S2")
        add2.input(0).connect(add1, "S4")
        add2.input(1).connect(inp3, "S3")
        mul1.input(0).connect(add1, "S5")
        mul1.input(1).connect(add2, "S6")
        out1.input(0).connect(mul1, "S7")

        # INP3 is deliberately not passed in ``inputs``; it is still expected
        # among the components because it is reachable through ADD2.
        mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg")

        assert {comp.name for comp in mac_sfg.components} == {
            "INP1",
            "INP2",
            "INP3",
            "ADD1",
            "ADD2",
            "MUL1",
            "OUT1",
            "S1",
            "S2",
            "S3",
            "S4",
            "S5",
            "S6",
            "S7",
        }
Angus Lothian
committed
class TestReplaceOperation:
# Tests for SFG.replace_operation(new_operation, graph_id=...): it returns an SFG
# in which the operation with that graph id has been replaced.
# NOTE(review): every test below reads `component_id`, but no assignment to it
# is visible. The assignment lines appear to have been lost in extraction
# (blame residue sits where they would be). The concrete graph ids cannot be
# inferred from this view; restore them from version control.
Angus Lothian
committed
def test_replace_addition_by_id(self, operation_tree):
sfg = SFG(outputs=[Output(operation_tree)])
Angus Lothian
committed
sfg = sfg.replace_operation(Multiplication(name="Multi"), graph_id=component_id)
Angus Lothian
committed
# The old id is gone and the new operation is reachable by its name.
assert component_id not in sfg._components_by_id.keys()
assert "Multi" in sfg._components_by_name.keys()
def test_replace_addition_large_tree(self, large_operation_tree):
sfg = SFG(outputs=[Output(large_operation_tree)])
Angus Lothian
committed
sfg = sfg.replace_operation(Multiplication(name="Multi"), graph_id=component_id)
Angus Lothian
committed
assert "Multi" in sfg._components_by_name.keys()
assert component_id not in sfg._components_by_id.keys()
def test_replace_no_input_component(self, operation_tree):
sfg = SFG(outputs=[Output(operation_tree)])
Angus Lothian
committed
const_ = sfg.find_by_id(component_id)
sfg = sfg.replace_operation(Constant(1), graph_id=component_id)
Angus Lothian
committed
# After replacement, the id no longer resolves to the original object.
assert const_ is not sfg.find_by_id(component_id)
def test_no_match_on_replace(self, large_operation_tree):
sfg = SFG(outputs=[Output(large_operation_tree)])
Angus Lothian
committed
# An id that matches no operation must raise ValueError.
with pytest.raises(
ValueError, match="No operation matching the criteria found"
):
sfg = sfg.replace_operation(
Multiplication(name="Multi"), graph_id=component_id
)
Angus Lothian
committed
def test_not_equal_input(self, large_operation_tree):
sfg = SFG(outputs=[Output(large_operation_tree)])
Angus Lothian
committed
# Replacing with an operation of a different input count must raise TypeError.
with pytest.raises(
TypeError,
match="The input count may not differ between the operations",
):
sfg = sfg.replace_operation(
Multiplication(name="Multi"), graph_id=component_id
)
Angus Lothian
committed
class TestInsertComponent:
# Tests for SFG.insert_operation(op, graph_id): it returns a new SFG with `op`
# inserted after the operation with that id. The original SFG is left unchanged,
# as the assertions below check.
def test_insert_component_in_sfg(self, large_operation_tree_names):
sfg = SFG(outputs=[Output(large_operation_tree_names)])
sqrt = SquareRoot()
# Insert a SquareRoot after "constant4"; the result is a separate SFG.
_sfg = sfg.insert_operation(sqrt, sfg.find_by_name("constant4")[0].graph_id)
Angus Lothian
committed
assert _sfg.evaluate() != sfg.evaluate()
assert any([isinstance(comp, SquareRoot) for comp in _sfg.operations])
assert not any([isinstance(comp, SquareRoot) for comp in sfg.operations])
# NOTE(review): the opening line of the next statement appears to have been
# lost in extraction (presumably `assert not isinstance(`), leaving the
# dangling argument list and `)` below. Restore before running.
sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation,
SquareRoot,
)
assert isinstance(
_sfg.find_by_name("constant4")[0]
.output(0)
.signals[0]
.destination.operation,
SquareRoot,
)
# In the original SFG, constant4 still feeds add2 directly; in the new SFG
# it no longer does.
assert sfg.find_by_name("constant4")[0].output(0).signals[
0
].destination.operation is sfg.find_by_id("add2")
assert _sfg.find_by_name("constant4")[0].output(0).signals[
0
].destination.operation is not _sfg.find_by_id("add2")
assert _sfg.find_by_id("sqrt0")
Loading
Loading full blame...