Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • da/B-ASIC
  • lukja239/B-ASIC
  • robal695/B-ASIC
3 results
Show changes
Commits on Source (2)
......@@ -240,3 +240,18 @@ class Butterfly(AbstractOperation):
def evaluate(self, a, b):
return a + b, a - b
class MAD(AbstractOperation):
"""Multiply-and-add operation.
TODO: More info.
"""
def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, src2: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 3, output_count = 1, name = name, input_sources = [src0, src1, src2])
@property
def type_name(self) -> TypeName:
return "mad"
def evaluate(self, a, b, c):
return a * b + c
......@@ -186,6 +186,12 @@ class Operation(GraphComponent, SignalSourceProvider):
"""Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index."""
raise NotImplementedError
@abstractmethod
def to_sfg(self) -> "SFG":
"""Convert the operation into its corresponding SFG.
If the operation is composed by multiple operations, the operation will be split.
"""
raise NotImplementedError
class AbstractOperation(Operation, AbstractGraphComponent):
"""Generic abstract operation class which most implementations will derive from.
......@@ -361,6 +367,30 @@ class AbstractOperation(Operation, AbstractGraphComponent):
pass
return [self]
def to_sfg(self) -> "SFG":
# Import here to avoid circular imports.
from b_asic.special_operations import Input, Output
from b_asic.signal_flow_graph import SFG
inputs = [Input() for i in range(self.input_count)]
try:
last_operations = self.evaluate(*inputs)
if isinstance(last_operations, Operation):
last_operations = [last_operations]
outputs = [Output(o) for o in last_operations]
except TypeError:
operation_copy = self.copy_component()
inputs = []
for i in range(self.input_count):
_input = Input()
operation_copy.input(i).connect(_input)
inputs.append(_input)
outputs = [Output(operation_copy)]
return SFG(inputs=inputs, outputs=outputs)
def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
if output_index < 0 or output_index >= self.output_count:
raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})")
......
......@@ -283,6 +283,9 @@ class SFG(AbstractOperation):
def split(self) -> Iterable[Operation]:
return self.operations
def to_sfg(self) -> 'SFG':
return self
def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
if output_index < 0 or output_index >= self.output_count:
......
......@@ -89,4 +89,3 @@ def test_division_overload():
assert isinstance(div3, Division)
assert div3.input(0).signals[0].source.operation.value == 5
assert div3.input(1).signals == div2.output(0).signals
......@@ -6,7 +6,6 @@ from b_asic import \
Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \
SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly
class TestConstant:
def test_constant_positive(self):
test_operation = Constant(3)
......
import pytest
from b_asic import Constant, Addition
from b_asic import Constant, Addition, MAD, Butterfly, SquareRoot
class TestTraverse:
def test_traverse_single_tree(self, operation):
......@@ -22,4 +22,32 @@ class TestTraverse:
assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4
def test_traverse_loop(self, operation_graph_with_cycle):
assert len(list(operation_graph_with_cycle.traverse())) == 8
\ No newline at end of file
assert len(list(operation_graph_with_cycle.traverse())) == 8
class TestToSfg:
def test_convert_mad_to_sfg(self):
mad1 = MAD()
mad1_sfg = mad1.to_sfg()
assert mad1.evaluate(1,1,1) == mad1_sfg.evaluate(1,1,1)
assert len(mad1_sfg.operations) == 6
def test_butterfly_to_sfg(self):
but1 = Butterfly()
but1_sfg = but1.to_sfg()
assert but1.evaluate(1,1)[0] == but1_sfg.evaluate(1,1)[0]
assert but1.evaluate(1,1)[1] == but1_sfg.evaluate(1,1)[1]
assert len(but1_sfg.operations) == 8
def test_add_to_sfg(self):
add1 = Addition()
add1_sfg = add1.to_sfg()
assert len(add1_sfg.operations) == 4
def test_sqrt_to_sfg(self):
sqrt1 = SquareRoot()
sqrt1_sfg = sqrt1.to_sfg()
assert len(sqrt1_sfg.operations) == 3