Skip to content
Snippets Groups Projects
Commit f9cbafb1 authored by Jacob Wahlman's avatar Jacob Wahlman :ok_hand:
Browse files

Merge branch 'develop' of gitlab.liu.se:PUM_TDDD96/B-ASIC into 87-resize-gui-window

parents 39ff67c3 6dcab2af
No related branches found
No related tags found
1 merge request!46Resolve "Resize GUI Window"
Pipeline #14936 passed
......@@ -240,3 +240,18 @@ class Butterfly(AbstractOperation):
def evaluate(self, a, b):
return a + b, a - b
class MAD(AbstractOperation):
"""Multiply-and-add operation.
TODO: More info.
"""
def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, src2: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 3, output_count = 1, name = name, input_sources = [src0, src1, src2])
@property
def type_name(self) -> TypeName:
return "mad"
def evaluate(self, a, b, c):
return a * b + c
......@@ -186,6 +186,12 @@ class Operation(GraphComponent, SignalSourceProvider):
"""Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index."""
raise NotImplementedError
@abstractmethod
def to_sfg(self) -> "SFG":
"""Convert the operation into its corresponding SFG.
If the operation is composed by multiple operations, the operation will be split.
"""
raise NotImplementedError
class AbstractOperation(Operation, AbstractGraphComponent):
"""Generic abstract operation class which most implementations will derive from.
......@@ -361,6 +367,30 @@ class AbstractOperation(Operation, AbstractGraphComponent):
pass
return [self]
def to_sfg(self) -> "SFG":
# Import here to avoid circular imports.
from b_asic.special_operations import Input, Output
from b_asic.signal_flow_graph import SFG
inputs = [Input() for i in range(self.input_count)]
try:
last_operations = self.evaluate(*inputs)
if isinstance(last_operations, Operation):
last_operations = [last_operations]
outputs = [Output(o) for o in last_operations]
except TypeError:
operation_copy = self.copy_component()
inputs = []
for i in range(self.input_count):
_input = Input()
operation_copy.input(i).connect(_input)
inputs.append(_input)
outputs = [Output(operation_copy)]
return SFG(inputs=inputs, outputs=outputs)
def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
if output_index < 0 or output_index >= self.output_count:
raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})")
......
......@@ -283,6 +283,9 @@ class SFG(AbstractOperation):
def split(self) -> Iterable[Operation]:
return self.operations
def to_sfg(self) -> 'SFG':
return self
def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
if output_index < 0 or output_index >= self.output_count:
......
......@@ -89,4 +89,3 @@ def test_division_overload():
assert isinstance(div3, Division)
assert div3.input(0).signals[0].source.operation.value == 5
assert div3.input(1).signals == div2.output(0).signals
......@@ -6,7 +6,6 @@ from b_asic import \
Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \
SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly
class TestConstant:
def test_constant_positive(self):
test_operation = Constant(3)
......
import pytest
from b_asic import Constant, Addition
from b_asic import Constant, Addition, MAD, Butterfly, SquareRoot
class TestTraverse:
def test_traverse_single_tree(self, operation):
......@@ -22,4 +22,32 @@ class TestTraverse:
assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4
def test_traverse_loop(self, operation_graph_with_cycle):
assert len(list(operation_graph_with_cycle.traverse())) == 8
\ No newline at end of file
assert len(list(operation_graph_with_cycle.traverse())) == 8
class TestToSfg:
def test_convert_mad_to_sfg(self):
mad1 = MAD()
mad1_sfg = mad1.to_sfg()
assert mad1.evaluate(1,1,1) == mad1_sfg.evaluate(1,1,1)
assert len(mad1_sfg.operations) == 6
def test_butterfly_to_sfg(self):
but1 = Butterfly()
but1_sfg = but1.to_sfg()
assert but1.evaluate(1,1)[0] == but1_sfg.evaluate(1,1)[0]
assert but1.evaluate(1,1)[1] == but1_sfg.evaluate(1,1)[1]
assert len(but1_sfg.operations) == 8
def test_add_to_sfg(self):
add1 = Addition()
add1_sfg = add1.to_sfg()
assert len(add1_sfg.operations) == 4
def test_sqrt_to_sfg(self):
sqrt1 = SquareRoot()
sqrt1_sfg = sqrt1.to_sfg()
assert len(sqrt1_sfg.operations) == 3
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment