Skip to content
Snippets Groups Projects
Commit 3dee025a authored by Kevin's avatar Kevin
Browse files

Added test cases for split and to_sfg

parents ec6d300b 676a6e96
No related branches found
No related tags found
1 merge request!42Resolve "Operation to SFG Conversion"
Pipeline #14365 failed
......@@ -229,3 +229,18 @@ class Butterfly(AbstractOperation):
def evaluate(self, a, b):
return a + b, a - b
class MAD(AbstractOperation):
"""Multiply-and-add operation.
TODO: More info.
"""
def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, src2: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 3, output_count = 1, name = name, input_sources = [src0, src1, src2])
@property
def type_name(self) -> TypeName:
return "mad"
def evaluate(self, a, b, c):
return a * b + c
......@@ -181,11 +181,12 @@ class Operation(GraphComponent, SignalSourceProvider):
raise NotImplementedError
@abstractmethod
def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
"""Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index."""
def to_sfg(self) -> "SFG":
"""Convert the operation into its corresponding SFG.
If the operation is composed by multiple operations, the operation will be splitted.
"""
raise NotImplementedError
class AbstractOperation(Operation, AbstractGraphComponent):
"""Generic abstract operation class which most implementations will derive from.
TODO: More info.
......@@ -334,7 +335,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
# Import here to avoid circular imports.
from b_asic.special_operations import Input
try:
result = self.evaluate([Input()] * self.input_count)
result = self.evaluate(*[Input()] * self.input_count)
if isinstance(result, collections.Sequence) and all(isinstance(e, Operation) for e in result):
return result
if isinstance(result, Operation):
......@@ -345,6 +346,9 @@ class AbstractOperation(Operation, AbstractGraphComponent):
pass
return [self]
def to_sfg(self) -> "SFG":
pass
def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
if output_index < 0 or output_index >= self.output_count:
raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})")
......
......@@ -89,4 +89,3 @@ def test_division_overload():
assert isinstance(div3, Division)
assert div3.input(0).signals[0].source.operation.value == 5
assert div3.input(1).signals == div2.output(0).signals
......@@ -4,8 +4,7 @@ B-ASIC test suite for the core operations.
from b_asic import \
Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \
SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly
SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly, MAD
class TestConstant:
def test_constant_positive(self):
......
import pytest
from b_asic import Constant, Addition
from b_asic import Constant, Addition, MAD, Butterfly
class TestTraverse:
def test_traverse_single_tree(self, operation):
......@@ -22,4 +22,31 @@ class TestTraverse:
assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4
def test_traverse_loop(self, operation_graph_with_cycle):
assert len(list(operation_graph_with_cycle.traverse())) == 8
\ No newline at end of file
assert len(list(operation_graph_with_cycle.traverse())) == 8
class TestSplit:
def test_split_mad(self):
mad1 = MAD()
split = mad1.split()
assert len(split) == 1
assert len(list(split[0].traverse())) == 10
def test_split_butterfly(self):
but1 = Butterfly()
split = but1.split()
assert len(split) == 2
class TestToSfg:
def test_convert_mad_to_sfg(self):
mad1 = MAD()
mad1_sfg = mad1.to_sfg()
assert mad1.evaluate(1,1,1) == mad1_sfg.evaluate(1,1,1)
assert len(mad1_sfg.operations) == 6
def test_butterfly_to_sfg(self):
but1 = Butterfly()
but1_sfg = but1.to_sfg()
assert but1.evaluate(1,1) == but1_sfg.evaluate(1,1)
assert len(but1_sfg.operations) == 6
import pytest
from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication, SquareRoot, Butterfly
from b_asic import SFG, Signal, Input, Output, Constant, Addition, Subtraction, Multiplication, MAD, ConstantMultiplication, Butterfly
class TestInit:
......@@ -245,56 +245,3 @@ class TestReplaceComponents:
assert True
else:
assert False
class TestInsertComponent:
def test_insert_component_in_sfg(self, large_operation_tree_names):
sfg = SFG(outputs=[Output(large_operation_tree_names)])
sqrt = SquareRoot()
_sfg = sfg.insert_operation(sqrt, sfg.find_by_name("constant4")[0].graph_id)
assert _sfg.evaluate() != sfg.evaluate()
assert any([isinstance(comp, SquareRoot) for comp in _sfg.operations])
assert not any([isinstance(comp, SquareRoot) for comp in sfg.operations])
assert not isinstance(sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot)
assert isinstance(_sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot)
assert sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is sfg.find_by_id("add3")
assert _sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is not _sfg.find_by_id("add3")
assert _sfg.find_by_id("sqrt1").output(0).signals[0].destination.operation is _sfg.find_by_id("add3")
def test_insert_invalid_component_in_sfg(self, large_operation_tree):
sfg = SFG(outputs=[Output(large_operation_tree)])
# Should raise an exception for not matching input count to output count.
add4 = Addition()
with pytest.raises(Exception):
sfg.insert_operation(add4, "c1")
def test_insert_at_output(self, large_operation_tree):
sfg = SFG(outputs=[Output(large_operation_tree)])
# Should raise an exception for trying to insert an operation after an output.
sqrt = SquareRoot()
with pytest.raises(Exception):
_sfg = sfg.insert_operation(sqrt, "out1")
def test_insert_multiple_output_ports(self, butterfly_operation_tree):
sfg = SFG(outputs=list(map(Output, butterfly_operation_tree.outputs)))
_sfg = sfg.insert_operation(Butterfly(name="New Bfly"), "bfly3")
assert sfg.evaluate() != _sfg.evaluate()
assert len(sfg.find_by_name("New Bfly")) == 0
assert len(_sfg.find_by_name("New Bfly")) == 1
# The old bfly3 becomes bfly4 in the new sfg since it is "moved" back.
assert sfg.find_by_id("bfly4") is None
assert _sfg.find_by_id("bfly4") is not None
assert sfg.find_by_id("bfly3").output(0).signals[0].destination.operation.name is not "New Bfly"
assert _sfg.find_by_id("bfly4").output(0).signals[0].destination.operation.name is "New Bfly"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment