diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 3e6cd78755f727d034723e2f02e707225cbe9611..ec7306c6f4c97b5c0377794e48524d09c7ed159b 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -240,3 +240,18 @@ class Butterfly(AbstractOperation): def evaluate(self, a, b): return a + b, a - b + +class MAD(AbstractOperation): + """Multiply-and-add operation. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, src2: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 3, output_count = 1, name = name, input_sources = [src0, src1, src2]) + + @property + def type_name(self) -> TypeName: + return "mad" + + def evaluate(self, a, b, c): + return a * b + c diff --git a/b_asic/operation.py b/b_asic/operation.py index f8ac22e2a1d26e13365d0d742775de6f1f020057..90e9adeff122d0bfbcde9dc1e0a9126aa42b939e 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -186,6 +186,12 @@ class Operation(GraphComponent, SignalSourceProvider): """Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index.""" raise NotImplementedError + @abstractmethod + def to_sfg(self) -> "SFG": + """Convert the operation into its corresponding SFG. + If the operation is composed by multiple operations, the operation will be split. + """ + raise NotImplementedError class AbstractOperation(Operation, AbstractGraphComponent): """Generic abstract operation class which most implementations will derive from. @@ -361,6 +367,30 @@ class AbstractOperation(Operation, AbstractGraphComponent): pass return [self] + def to_sfg(self) -> "SFG": + # Import here to avoid circular imports. + from b_asic.special_operations import Input, Output + from b_asic.signal_flow_graph import SFG + + inputs = [Input() for i in range(self.input_count)] + + try: + last_operations = self.evaluate(*inputs) + if isinstance(last_operations, Operation): + last_operations = [last_operations] + outputs = [Output(o) for o in last_operations] + except TypeError: + operation_copy = self.copy_component() + inputs = [] + for i in range(self.input_count): + _input = Input() + operation_copy.input(i).connect(_input) + inputs.append(_input) + + outputs = [Output(operation_copy)] + + return SFG(inputs=inputs, outputs=outputs) + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: if output_index < 0 or output_index >= self.output_count: raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 6483cfc1476047bcbe897b871cc179990b894c4d..6529dfd7d355f062c18ae16f2307587ec5e4cd80 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -283,6 +283,9 @@ class SFG(AbstractOperation): def split(self) -> Iterable[Operation]: return self.operations + + def to_sfg(self) -> 'SFG': + return self def inputs_required_for_output(self, output_index: int) -> Iterable[int]: if output_index < 0 or output_index >= self.output_count: diff --git a/test/test_abstract_operation.py b/test/test_abstract_operation.py index 5423ecdf08c420df5dccc6393c3ad6637961172b..9163fce2a955c7fbc68d5d24de86896d251934da 100644 --- a/test/test_abstract_operation.py +++ b/test/test_abstract_operation.py @@ -89,4 +89,3 @@ def test_division_overload(): assert isinstance(div3, Division) assert div3.input(0).signals[0].source.operation.value == 5 assert div3.input(1).signals == div2.output(0).signals - diff --git a/test/test_core_operations.py b/test/test_core_operations.py index 2eb341da88a851ac0fd26939da64377ea27963a1..6a0493c60965579bd843e0b514bd7f9b9a0e4707 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -6,7 +6,6 @@ from b_asic import \ Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \ SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly - class TestConstant: def test_constant_positive(self): test_operation = Constant(3) diff --git a/test/test_operation.py b/test/test_operation.py index b76ba16d11425c0ce868e4fa0b4c88d9f862e23f..77e9ba3cbd0eaa75886b5a7e5d11f00f6cfeb479 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -1,6 +1,6 @@ import pytest -from b_asic import Constant, Addition +from b_asic import Constant, Addition, MAD, Butterfly, SquareRoot class TestTraverse: def test_traverse_single_tree(self, operation): @@ -22,4 +22,32 @@ class TestTraverse: assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4 def test_traverse_loop(self, operation_graph_with_cycle): - assert len(list(operation_graph_with_cycle.traverse())) == 8 \ No newline at end of file + assert len(list(operation_graph_with_cycle.traverse())) == 8 + +class TestToSfg: + def test_convert_mad_to_sfg(self): + mad1 = MAD() + mad1_sfg = mad1.to_sfg() + + assert mad1.evaluate(1,1,1) == mad1_sfg.evaluate(1,1,1) + assert len(mad1_sfg.operations) == 6 + + def test_butterfly_to_sfg(self): + but1 = Butterfly() + but1_sfg = but1.to_sfg() + + assert but1.evaluate(1,1)[0] == but1_sfg.evaluate(1,1)[0] + assert but1.evaluate(1,1)[1] == but1_sfg.evaluate(1,1)[1] + assert len(but1_sfg.operations) == 8 + + def test_add_to_sfg(self): + add1 = Addition() + add1_sfg = add1.to_sfg() + + assert len(add1_sfg.operations) == 4 + + def test_sqrt_to_sfg(self): + sqrt1 = SquareRoot() + sqrt1_sfg = sqrt1.to_sfg() + + assert len(sqrt1_sfg.operations) == 3