Skip to content
Snippets Groups Projects
Commit 58084c2a authored by Simon Bjurek's avatar Simon Bjurek
Browse files

Changed sfg generator to use DontCare.

parent f85cd36a
No related branches found
No related tags found
1 merge request!469Add matrix inversion sfg generator
Pipeline #156534 failed
...@@ -1102,6 +1102,7 @@ class MAD(AbstractOperation): ...@@ -1102,6 +1102,7 @@ class MAD(AbstractOperation):
class MADS(AbstractOperation): class MADS(AbstractOperation):
__slots__ = ( __slots__ = (
"_is_add", "_is_add",
"_override_zero_on_src0",
"_src0", "_src0",
"_src1", "_src1",
"_src2", "_src2",
...@@ -1111,6 +1112,7 @@ class MADS(AbstractOperation): ...@@ -1111,6 +1112,7 @@ class MADS(AbstractOperation):
"_execution_time", "_execution_time",
) )
_is_add: Optional[bool] _is_add: Optional[bool]
_override_zero_on_src0: Optional[bool]
_src0: Optional[SignalSourceProvider] _src0: Optional[SignalSourceProvider]
_src1: Optional[SignalSourceProvider] _src1: Optional[SignalSourceProvider]
_src2: Optional[SignalSourceProvider] _src2: Optional[SignalSourceProvider]
...@@ -1124,6 +1126,7 @@ class MADS(AbstractOperation): ...@@ -1124,6 +1126,7 @@ class MADS(AbstractOperation):
def __init__( def __init__(
self, self,
is_add: Optional[bool] = True, is_add: Optional[bool] = True,
override_zero_on_src0: Optional[bool] = False,
src0: Optional[SignalSourceProvider] = None, src0: Optional[SignalSourceProvider] = None,
src1: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None,
src2: Optional[SignalSourceProvider] = None, src2: Optional[SignalSourceProvider] = None,
...@@ -1143,13 +1146,23 @@ class MADS(AbstractOperation): ...@@ -1143,13 +1146,23 @@ class MADS(AbstractOperation):
execution_time=execution_time, execution_time=execution_time,
) )
self.set_param("is_add", is_add) self.set_param("is_add", is_add)
self.set_param("override_zero_on_src0", override_zero_on_src0)
@classmethod @classmethod
def type_name(cls) -> TypeName: def type_name(cls) -> TypeName:
return TypeName("mads") return TypeName("mads")
def evaluate(self, a, b, c): def evaluate(self, a, b, c):
return a + b * c if self.is_add else a - b * c if self.is_add:
if self.override_zero_on_src0:
return b * c
else:
return a + b * c
else:
if self.override_zero_on_src0:
return -b * c
else:
return a - b * c
@property @property
def is_add(self) -> bool: def is_add(self) -> bool:
...@@ -1161,11 +1174,21 @@ class MADS(AbstractOperation): ...@@ -1161,11 +1174,21 @@ class MADS(AbstractOperation):
"""Set if operation is an addition.""" """Set if operation is an addition."""
self.set_param("is_add", is_add) self.set_param("is_add", is_add)
@property
def override_zero_on_src0(self) -> bool:
"""Get if operation is overriding a zero on port src0."""
return self.param("override_zero_on_src0")
@override_zero_on_src0.setter
def override_zero_on_src0(self, override_zero_on_src0: bool) -> None:
"""Set if operation is overriding a zero on port src0."""
self.set_param("override_zero_on_src0", override_zero_on_src0)
@property @property
def is_linear(self) -> bool: def is_linear(self) -> bool:
return ( return (
self.input(0).connected_source.operation.is_constant self.input(1).connected_source.operation.is_constant
or self.input(1).connected_source.operation.is_constant or self.input(2).connected_source.operation.is_constant
) )
def swap_io(self) -> None: def swap_io(self) -> None:
...@@ -1598,6 +1621,51 @@ class Shift(AbstractOperation): ...@@ -1598,6 +1621,51 @@ class Shift(AbstractOperation):
self.set_param("value", value) self.set_param("value", value)
class DontCare(AbstractOperation):
r"""
Dont-care operation
Used for ignoring the input to another operation and thus avoiding dangling input nodes.
Parameters
----------
name : Name, optional
Operation name.
"""
__slots__ = "_name"
_name: Name
is_linear = True
def __init__(self, name: Name = ""):
"""Construct a DontCare operation."""
super().__init__(
input_count=0,
output_count=1,
name=name,
latency_offsets={"out0": 0},
)
@classmethod
def type_name(cls) -> TypeName:
return TypeName("dontcare")
def evaluate(self):
return 0
@property
def latency(self) -> int:
return self.latency_offsets["out0"]
def __repr__(self) -> str:
return "DontCare()"
def __str__(self) -> str:
return "dontcare"
class Sink(AbstractOperation): class Sink(AbstractOperation):
r""" r"""
Sink operation. Sink operation.
......
...@@ -12,9 +12,8 @@ from b_asic.core_operations import ( ...@@ -12,9 +12,8 @@ from b_asic.core_operations import (
MADS, MADS,
Addition, Addition,
Butterfly, Butterfly,
ComplexConjugate,
Constant,
ConstantMultiplication, ConstantMultiplication,
DontCare,
Name, Name,
Reciprocal, Reciprocal,
SymmetricTwoportAdaptor, SymmetricTwoportAdaptor,
...@@ -436,7 +435,19 @@ def radix_2_dif_fft(points: int) -> SFG: ...@@ -436,7 +435,19 @@ def radix_2_dif_fft(points: int) -> SFG:
return SFG(inputs=inputs, outputs=outputs) return SFG(inputs=inputs, outputs=outputs)
def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG: def ldlt_matrix_inverse(N: int) -> SFG:
"""Generates an SFG for the LDLT matrix inverse algorithm.
Parameters
----------
N : int
Dimension of the square input matrix.
Returns
-------
SFG
Signal Flow Graph
"""
inputs = [] inputs = []
A = [[None for _ in range(N)] for _ in range(N)] A = [[None for _ in range(N)] for _ in range(N)]
for i in range(N): for i in range(N):
...@@ -457,7 +468,7 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG: ...@@ -457,7 +468,7 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG:
# R*di*R^T factorization # R*di*R^T factorization
for i in range(N): for i in range(N):
for k in range(i): for k in range(i):
D[i] = MADS(False, D[i], M[k][i], R[k][i]) D[i] = MADS(False, False, D[i], M[k][i], R[k][i])
D_inv[i] = Reciprocal(D[i]) D_inv[i] = Reciprocal(D[i])
...@@ -465,14 +476,14 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG: ...@@ -465,14 +476,14 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG:
R[i][j] = A[i][j] R[i][j] = A[i][j]
for k in range(i): for k in range(i):
R[i][j] = MADS(False, R[i][j], M[k][i], R[k][j]) R[i][j] = MADS(False, False, R[i][j], M[k][i], R[k][j])
if is_complex: # if is_complex:
M[i][j] = ComplexConjugate(R[i][j]) # M[i][j] = ComplexConjugate(R[i][j])
else: # else:
M[i][j] = R[i][j] M[i][j] = R[i][j]
R[i][j] = MADS(True, Constant(0, name="0"), R[i][j], D_inv[i]) R[i][j] = MADS(True, True, DontCare(), R[i][j], D_inv[i])
# back substitution # back substitution
A_inv = [[None for _ in range(N)] for _ in range(N)] A_inv = [[None for _ in range(N)] for _ in range(N)]
...@@ -481,14 +492,16 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG: ...@@ -481,14 +492,16 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG:
for j in reversed(range(i + 1)): for j in reversed(range(i + 1)):
for k in reversed(range(j + 1, N)): for k in reversed(range(j + 1, N)):
if k == N - 1 and i != j: if k == N - 1 and i != j:
A_inv[j][i] = MADS( A_inv[j][i] = MADS(False, True, DontCare(), R[j][k], A_inv[i][k])
False, Constant(0, name="0"), R[j][k], A_inv[i][k]
)
else: else:
if A_inv[i][k]: if A_inv[i][k]:
A_inv[j][i] = MADS(False, A_inv[j][i], R[j][k], A_inv[i][k]) A_inv[j][i] = MADS(
False, False, A_inv[j][i], R[j][k], A_inv[i][k]
)
else: else:
A_inv[j][i] = MADS(False, A_inv[j][i], R[j][k], A_inv[k][i]) A_inv[j][i] = MADS(
False, False, A_inv[j][i], R[j][k], A_inv[k][i]
)
outputs = [] outputs = []
for i in range(N): for i in range(N):
......
...@@ -5,6 +5,7 @@ import pytest ...@@ -5,6 +5,7 @@ import pytest
from b_asic import ( from b_asic import (
MAD, MAD,
MADS, MADS,
SFG,
Absolute, Absolute,
Addition, Addition,
AddSub, AddSub,
...@@ -13,11 +14,13 @@ from b_asic import ( ...@@ -13,11 +14,13 @@ from b_asic import (
Constant, Constant,
ConstantMultiplication, ConstantMultiplication,
Division, Division,
DontCare,
Input, Input,
LeftShift, LeftShift,
Max, Max,
Min, Min,
Multiplication, Multiplication,
Output,
Reciprocal, Reciprocal,
RightShift, RightShift,
Shift, Shift,
...@@ -343,19 +346,33 @@ class TestMADS: ...@@ -343,19 +346,33 @@ class TestMADS:
test_operation = MADS(is_add=True) test_operation = MADS(is_add=True)
assert test_operation.evaluate_output(0, [3 + 6j, 2 + 6j, 1 + 1j]) == -1 + 14j assert test_operation.evaluate_output(0, [3 + 6j, 2 + 6j, 1 + 1j]) == -1 + 14j
def test_mads_zero_override(self):
test_operation = MADS(is_add=True, override_zero_on_src0=True)
assert test_operation.evaluate_output(0, [1, 1, 1]) == 1
def test_mads_sub_zero_override(self):
test_operation = MADS(is_add=False, override_zero_on_src0=True)
assert test_operation.evaluate_output(0, [1, 1, 1]) == -1
def test_mads_is_linear(self): def test_mads_is_linear(self):
test_operation = MADS( test_operation = MADS(
Constant(3), Addition(Input(), Constant(3)), Addition(Input(), Constant(3)) src0=Constant(3),
src1=Addition(Input(), Constant(3)),
src2=Addition(Input(), Constant(3)),
) )
assert not test_operation.is_linear assert not test_operation.is_linear
test_operation = MADS( test_operation = MADS(
Addition(Input(), Constant(3)), Constant(3), Addition(Input(), Constant(3)) src0=Addition(Input(), Constant(3)),
src1=Constant(3),
src2=Addition(Input(), Constant(3)),
) )
assert test_operation.is_linear assert test_operation.is_linear
test_operation = MADS( test_operation = MADS(
Addition(Input(), Constant(3)), Addition(Input(), Constant(3)), Constant(3) src0=Addition(Input(), Constant(3)),
src1=Addition(Input(), Constant(3)),
src2=Constant(3),
) )
assert test_operation.is_linear assert test_operation.is_linear
...@@ -381,6 +398,22 @@ class TestMADS: ...@@ -381,6 +398,22 @@ class TestMADS:
test_operation.is_add = False test_operation.is_add = False
assert not test_operation.is_add assert not test_operation.is_add
def test_mads_override_zero_on_src0_getter(self):
test_operation = MADS(override_zero_on_src0=False)
assert not test_operation.override_zero_on_src0
test_operation = MADS(override_zero_on_src0=True)
assert test_operation.override_zero_on_src0
def test_mads_override_zero_on_src0_setter(self):
test_operation = MADS(override_zero_on_src0=False)
test_operation.override_zero_on_src0 = True
assert test_operation.override_zero_on_src0
test_operation = MADS(override_zero_on_src0=True)
test_operation.override_zero_on_src0 = False
assert not test_operation.override_zero_on_src0
class TestRightShift: class TestRightShift:
"""Tests for RightShift class.""" """Tests for RightShift class."""
...@@ -556,6 +589,33 @@ class TestDepends: ...@@ -556,6 +589,33 @@ class TestDepends:
assert set(bfly1.inputs_required_for_output(1)) == {0, 1} assert set(bfly1.inputs_required_for_output(1)) == {0, 1}
class TestDontCare:
def test_create_sfg_with_dontcare(self):
i1 = Input()
dc = DontCare()
a = Addition(i1, dc)
o = Output(a)
sfg = SFG([i1], [o])
assert sfg.output_count == 1
assert sfg.input_count == 1
assert sfg.evaluate_output(0, [0]) == 0
assert sfg.evaluate_output(0, [1]) == 1
def test_dontcare_latency_getter(self):
test_operation = DontCare()
assert test_operation.latency == 0
def test_dontcare_repr(self):
test_operation = DontCare()
assert repr(test_operation) == "DontCare()"
def test_dontcare_str(self):
test_operation = DontCare()
assert str(test_operation) == "dontcare"
class TestSink: class TestSink:
def test_create_sfg_with_sink(self): def test_create_sfg_with_sink(self):
bfly = Butterfly() bfly = Butterfly()
......
...@@ -644,7 +644,7 @@ class TestRadix2FFT: ...@@ -644,7 +644,7 @@ class TestRadix2FFT:
class TestLdltMatrixInverse: class TestLdltMatrixInverse:
def test_1x1(self): def test_1x1(self):
sfg = ldlt_matrix_inverse(N=1, is_complex=False) sfg = ldlt_matrix_inverse(N=1)
assert len(sfg.inputs) == 1 assert len(sfg.inputs) == 1
assert len(sfg.outputs) == 1 assert len(sfg.outputs) == 1
...@@ -661,7 +661,7 @@ class TestLdltMatrixInverse: ...@@ -661,7 +661,7 @@ class TestLdltMatrixInverse:
assert np.isclose(res["0"], 0.2) assert np.isclose(res["0"], 0.2)
def test_2x2_simple_spd(self): def test_2x2_simple_spd(self):
sfg = ldlt_matrix_inverse(N=2, is_complex=False) sfg = ldlt_matrix_inverse(N=2)
assert len(sfg.inputs) == 3 assert len(sfg.inputs) == 3
assert len(sfg.outputs) == 3 assert len(sfg.outputs) == 3
...@@ -683,7 +683,7 @@ class TestLdltMatrixInverse: ...@@ -683,7 +683,7 @@ class TestLdltMatrixInverse:
assert np.isclose(res["2"], A_inv[1, 1]) assert np.isclose(res["2"], A_inv[1, 1])
def test_3x3_simple_spd(self): def test_3x3_simple_spd(self):
sfg = ldlt_matrix_inverse(N=3, is_complex=False) sfg = ldlt_matrix_inverse(N=3)
assert len(sfg.inputs) == 6 assert len(sfg.inputs) == 6
assert len(sfg.outputs) == 6 assert len(sfg.outputs) == 6
...@@ -717,7 +717,7 @@ class TestLdltMatrixInverse: ...@@ -717,7 +717,7 @@ class TestLdltMatrixInverse:
def test_5x5_random_spd(self): def test_5x5_random_spd(self):
N = 5 N = 5
sfg = ldlt_matrix_inverse(N=N, is_complex=False) sfg = ldlt_matrix_inverse(N=N)
assert len(sfg.inputs) == 15 assert len(sfg.inputs) == 15
assert len(sfg.outputs) == 15 assert len(sfg.outputs) == 15
...@@ -746,7 +746,7 @@ class TestLdltMatrixInverse: ...@@ -746,7 +746,7 @@ class TestLdltMatrixInverse:
def test_20x20_random_spd(self): def test_20x20_random_spd(self):
N = 20 N = 20
sfg = ldlt_matrix_inverse(N=N, is_complex=False) sfg = ldlt_matrix_inverse(N=N)
A = self._generate_random_spd_matrix(N) A = self._generate_random_spd_matrix(N)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment