diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 296803e3e55b7b92d85f52b91b30533ecdfbc0b6..3741d90265b5f7a04bcd02ace789a7e1c6697f6d 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -229,3 +229,18 @@ class Butterfly(AbstractOperation): def evaluate(self, a, b): return a + b, a - b + +class MAD(AbstractOperation): + """Multiply-and-add operation. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, src2: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 3, output_count = 1, name = name, input_sources = [src0, src1, src2]) + + @property + def type_name(self) -> TypeName: + return "mad" + + def evaluate(self, a, b, c): + return a * b + c diff --git a/b_asic/operation.py b/b_asic/operation.py index bb66e26b30a4a14b116800300d5a00d0855945f2..f5dac7b776c314fb833398fe02415d9104a1bf25 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -181,11 +181,12 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def inputs_required_for_output(self, output_index: int) -> Iterable[int]: - """Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index.""" + def to_sfg(self) -> "SFG": + """Convert the operation into its corresponding SFG. + If the operation is composed by multiple operations, the operation will be splitted. + """ raise NotImplementedError - class AbstractOperation(Operation, AbstractGraphComponent): """Generic abstract operation class which most implementations will derive from. TODO: More info. @@ -334,7 +335,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): # Import here to avoid circular imports. from b_asic.special_operations import Input try: - result = self.evaluate([Input()] * self.input_count) + result = self.evaluate(*[Input()] * self.input_count) if isinstance(result, collections.Sequence) and all(isinstance(e, Operation) for e in result): return result if isinstance(result, Operation): @@ -345,6 +346,9 @@ class AbstractOperation(Operation, AbstractGraphComponent): pass return [self] + def to_sfg(self) -> "SFG": + pass + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: if output_index < 0 or output_index >= self.output_count: raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") diff --git a/test/test_abstract_operation.py b/test/test_abstract_operation.py index 5423ecdf08c420df5dccc6393c3ad6637961172b..9163fce2a955c7fbc68d5d24de86896d251934da 100644 --- a/test/test_abstract_operation.py +++ b/test/test_abstract_operation.py @@ -89,4 +89,3 @@ def test_division_overload(): assert isinstance(div3, Division) assert div3.input(0).signals[0].source.operation.value == 5 assert div3.input(1).signals == div2.output(0).signals - diff --git a/test/test_core_operations.py b/test/test_core_operations.py index 4d0039b558e81c5cd74f151f93f0bc0194a702d5..6a1aacdd54abf27fc3e37553bedc4809d871588f 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -4,8 +4,7 @@ B-ASIC test suite for the core operations. from b_asic import \ Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \ - SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly - + SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly, MAD class TestConstant: def test_constant_positive(self): diff --git a/test/test_operation.py b/test/test_operation.py index b76ba16d11425c0ce868e4fa0b4c88d9f862e23f..8c0d26d877e976f81e4c9986014ccd5187c4dba3 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -1,6 +1,6 @@ import pytest -from b_asic import Constant, Addition +from b_asic import Constant, Addition, MAD, Butterfly class TestTraverse: def test_traverse_single_tree(self, operation): @@ -22,4 +22,31 @@ class TestTraverse: assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4 def test_traverse_loop(self, operation_graph_with_cycle): - assert len(list(operation_graph_with_cycle.traverse())) == 8 \ No newline at end of file + assert len(list(operation_graph_with_cycle.traverse())) == 8 + +class TestSplit: + def test_split_mad(self): + mad1 = MAD() + split = mad1.split() + assert len(split) == 1 + assert len(list(split[0].traverse())) == 10 + + def test_split_butterfly(self): + but1 = Butterfly() + split = but1.split() + assert len(split) == 2 + +class TestToSfg: + def test_convert_mad_to_sfg(self): + mad1 = MAD() + mad1_sfg = mad1.to_sfg() + + assert mad1.evaluate(1,1,1) == mad1_sfg.evaluate(1,1,1) + assert len(mad1_sfg.operations) == 6 + + def test_butterfly_to_sfg(self): + but1 = Butterfly() + but1_sfg = but1.to_sfg() + + assert but1.evaluate(1,1) == but1_sfg.evaluate(1,1) + assert len(but1_sfg.operations) == 6 diff --git a/test/test_sfg.py b/test/test_sfg.py index ea7eb1b878ac48e8e0f5a537b7c3b155de12a9a6..72503591db1b70c6b68e765ec87f6a3d17fff05a 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1,6 +1,6 @@ import pytest -from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication, SquareRoot, Butterfly +from b_asic import SFG, Signal, Input, Output, Constant, Addition, Subtraction, Multiplication, MAD, ConstantMultiplication, Butterfly class TestInit: @@ -245,56 +245,3 @@ class TestReplaceComponents: assert True else: assert False - - -class TestInsertComponent: - - def test_insert_component_in_sfg(self, large_operation_tree_names): - sfg = SFG(outputs=[Output(large_operation_tree_names)]) - sqrt = SquareRoot() - - _sfg = sfg.insert_operation(sqrt, sfg.find_by_name("constant4")[0].graph_id) - assert _sfg.evaluate() != sfg.evaluate() - - assert any([isinstance(comp, SquareRoot) for comp in _sfg.operations]) - assert not any([isinstance(comp, SquareRoot) for comp in sfg.operations]) - - assert not isinstance(sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot) - assert isinstance(_sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot) - - assert sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is sfg.find_by_id("add3") - assert _sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is not _sfg.find_by_id("add3") - assert _sfg.find_by_id("sqrt1").output(0).signals[0].destination.operation is _sfg.find_by_id("add3") - - def test_insert_invalid_component_in_sfg(self, large_operation_tree): - sfg = SFG(outputs=[Output(large_operation_tree)]) - - # Should raise an exception for not matching input count to output count. - add4 = Addition() - with pytest.raises(Exception): - sfg.insert_operation(add4, "c1") - - def test_insert_at_output(self, large_operation_tree): - sfg = SFG(outputs=[Output(large_operation_tree)]) - - # Should raise an exception for trying to insert an operation after an output. - sqrt = SquareRoot() - with pytest.raises(Exception): - _sfg = sfg.insert_operation(sqrt, "out1") - - def test_insert_multiple_output_ports(self, butterfly_operation_tree): - sfg = SFG(outputs=list(map(Output, butterfly_operation_tree.outputs))) - _sfg = sfg.insert_operation(Butterfly(name="New Bfly"), "bfly3") - - assert sfg.evaluate() != _sfg.evaluate() - - assert len(sfg.find_by_name("New Bfly")) == 0 - assert len(_sfg.find_by_name("New Bfly")) == 1 - - # The old bfly3 becomes bfly4 in the new sfg since it is "moved" back. - assert sfg.find_by_id("bfly4") is None - assert _sfg.find_by_id("bfly4") is not None - - assert sfg.find_by_id("bfly3").output(0).signals[0].destination.operation.name is not "New Bfly" - assert _sfg.find_by_id("bfly4").output(0).signals[0].destination.operation.name is "New Bfly" -