Skip to content
Snippets Groups Projects
Commit 5b9551b6 authored by Kevin Scott's avatar Kevin Scott Committed by Angus Lothian
Browse files

Resolve "Operation to SFG Conversion"

parent 85db5737
No related branches found
No related tags found
2 merge requests!67WIP: B-ASIC version 1.0.0 hotfix,!65B-ASIC version 1.0.0
......@@ -240,3 +240,18 @@ class Butterfly(AbstractOperation):
def evaluate(self, a, b):
return a + b, a - b
class MAD(AbstractOperation):
"""Multiply-and-add operation.
TODO: More info.
"""
def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, src2: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 3, output_count = 1, name = name, input_sources = [src0, src1, src2])
@property
def type_name(self) -> TypeName:
return "mad"
def evaluate(self, a, b, c):
return a * b + c
......@@ -186,6 +186,12 @@ class Operation(GraphComponent, SignalSourceProvider):
"""Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index."""
raise NotImplementedError
@abstractmethod
def to_sfg(self) -> "SFG":
"""Convert the operation into its corresponding SFG.
If the operation is composed by multiple operations, the operation will be split.
"""
raise NotImplementedError
class AbstractOperation(Operation, AbstractGraphComponent):
"""Generic abstract operation class which most implementations will derive from.
......@@ -361,6 +367,30 @@ class AbstractOperation(Operation, AbstractGraphComponent):
pass
return [self]
def to_sfg(self) -> "SFG":
# Import here to avoid circular imports.
from b_asic.special_operations import Input, Output
from b_asic.signal_flow_graph import SFG
inputs = [Input() for i in range(self.input_count)]
try:
last_operations = self.evaluate(*inputs)
if isinstance(last_operations, Operation):
last_operations = [last_operations]
outputs = [Output(o) for o in last_operations]
except TypeError:
operation_copy = self.copy_component()
inputs = []
for i in range(self.input_count):
_input = Input()
operation_copy.input(i).connect(_input)
inputs.append(_input)
outputs = [Output(operation_copy)]
return SFG(inputs=inputs, outputs=outputs)
def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
if output_index < 0 or output_index >= self.output_count:
raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})")
......
......@@ -283,6 +283,9 @@ class SFG(AbstractOperation):
def split(self) -> Iterable[Operation]:
return self.operations
def to_sfg(self) -> 'SFG':
return self
def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
if output_index < 0 or output_index >= self.output_count:
......
......@@ -89,4 +89,3 @@ def test_division_overload():
assert isinstance(div3, Division)
assert div3.input(0).signals[0].source.operation.value == 5
assert div3.input(1).signals == div2.output(0).signals
......@@ -6,7 +6,6 @@ from b_asic import \
Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \
SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly
class TestConstant:
def test_constant_positive(self):
test_operation = Constant(3)
......
import pytest
from b_asic import Constant, Addition
from b_asic import Constant, Addition, MAD, Butterfly, SquareRoot
class TestTraverse:
def test_traverse_single_tree(self, operation):
......@@ -22,4 +22,32 @@ class TestTraverse:
assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4
def test_traverse_loop(self, operation_graph_with_cycle):
assert len(list(operation_graph_with_cycle.traverse())) == 8
\ No newline at end of file
assert len(list(operation_graph_with_cycle.traverse())) == 8
class TestToSfg:
def test_convert_mad_to_sfg(self):
mad1 = MAD()
mad1_sfg = mad1.to_sfg()
assert mad1.evaluate(1,1,1) == mad1_sfg.evaluate(1,1,1)
assert len(mad1_sfg.operations) == 6
def test_butterfly_to_sfg(self):
but1 = Butterfly()
but1_sfg = but1.to_sfg()
assert but1.evaluate(1,1)[0] == but1_sfg.evaluate(1,1)[0]
assert but1.evaluate(1,1)[1] == but1_sfg.evaluate(1,1)[1]
assert len(but1_sfg.operations) == 8
def test_add_to_sfg(self):
add1 = Addition()
add1_sfg = add1.to_sfg()
assert len(add1_sfg.operations) == 4
def test_sqrt_to_sfg(self):
sqrt1 = SquareRoot()
sqrt1_sfg = sqrt1.to_sfg()
assert len(sqrt1_sfg.operations) == 3
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment