Skip to content
Snippets Groups Projects
Commit 8c656f9f authored by Simon Bjurek's avatar Simon Bjurek
Browse files

Changed sfg generator to use DontCare.

parent a5cd70ea
No related branches found
No related tags found
1 merge request!470Add slack time scheduling, redid earliest deadline, added max-fan-out and hybrid scheduler, also added example
......@@ -1102,6 +1102,7 @@ class MAD(AbstractOperation):
class MADS(AbstractOperation):
__slots__ = (
"_is_add",
"_override_zero_on_src0",
"_src0",
"_src1",
"_src2",
......@@ -1111,6 +1112,7 @@ class MADS(AbstractOperation):
"_execution_time",
)
_is_add: Optional[bool]
_override_zero_on_src0: Optional[bool]
_src0: Optional[SignalSourceProvider]
_src1: Optional[SignalSourceProvider]
_src2: Optional[SignalSourceProvider]
......@@ -1124,6 +1126,7 @@ class MADS(AbstractOperation):
def __init__(
self,
is_add: Optional[bool] = True,
override_zero_on_src0: Optional[bool] = False,
src0: Optional[SignalSourceProvider] = None,
src1: Optional[SignalSourceProvider] = None,
src2: Optional[SignalSourceProvider] = None,
......@@ -1143,13 +1146,23 @@ class MADS(AbstractOperation):
execution_time=execution_time,
)
self.set_param("is_add", is_add)
self.set_param("override_zero_on_src0", override_zero_on_src0)
@classmethod
def type_name(cls) -> TypeName:
return TypeName("mads")
def evaluate(self, a, b, c):
return a + b * c if self.is_add else a - b * c
if self.is_add:
if self.override_zero_on_src0:
return b * c
else:
return a + b * c
else:
if self.override_zero_on_src0:
return -b * c
else:
return a - b * c
@property
def is_add(self) -> bool:
......@@ -1161,11 +1174,21 @@ class MADS(AbstractOperation):
"""Set if operation is an addition."""
self.set_param("is_add", is_add)
@property
def override_zero_on_src0(self) -> bool:
"""Get if operation is overriding a zero on port src0."""
return self.param("override_zero_on_src0")
@override_zero_on_src0.setter
def override_zero_on_src0(self, override_zero_on_src0: bool) -> None:
"""Set if operation is overriding a zero on port src0."""
self.set_param("override_zero_on_src0", override_zero_on_src0)
@property
def is_linear(self) -> bool:
return (
self.input(0).connected_source.operation.is_constant
or self.input(1).connected_source.operation.is_constant
self.input(1).connected_source.operation.is_constant
or self.input(2).connected_source.operation.is_constant
)
def swap_io(self) -> None:
......@@ -1598,6 +1621,51 @@ class Shift(AbstractOperation):
self.set_param("value", value)
class DontCare(AbstractOperation):
r"""
Dont-care operation
Used for ignoring the input to another operation and thus avoiding dangling input nodes.
Parameters
----------
name : Name, optional
Operation name.
"""
__slots__ = "_name"
_name: Name
is_linear = True
def __init__(self, name: Name = ""):
"""Construct a DontCare operation."""
super().__init__(
input_count=0,
output_count=1,
name=name,
latency_offsets={"out0": 0},
)
@classmethod
def type_name(cls) -> TypeName:
return TypeName("dontcare")
def evaluate(self):
return 0
@property
def latency(self) -> int:
return self.latency_offsets["out0"]
def __repr__(self) -> str:
return "DontCare()"
def __str__(self) -> str:
return "dontcare"
class Sink(AbstractOperation):
r"""
Sink operation.
......
......@@ -12,9 +12,8 @@ from b_asic.core_operations import (
MADS,
Addition,
Butterfly,
ComplexConjugate,
Constant,
ConstantMultiplication,
DontCare,
Name,
Reciprocal,
SymmetricTwoportAdaptor,
......@@ -436,7 +435,19 @@ def radix_2_dif_fft(points: int) -> SFG:
return SFG(inputs=inputs, outputs=outputs)
def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG:
def ldlt_matrix_inverse(N: int) -> SFG:
"""Generates an SFG for the LDLT matrix inverse algorithm.
Parameters
----------
N : int
Dimension of the square input matrix.
Returns
-------
SFG
Signal Flow Graph
"""
inputs = []
A = [[None for _ in range(N)] for _ in range(N)]
for i in range(N):
......@@ -457,7 +468,7 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG:
# R*di*R^T factorization
for i in range(N):
for k in range(i):
D[i] = MADS(False, D[i], M[k][i], R[k][i])
D[i] = MADS(False, False, D[i], M[k][i], R[k][i])
D_inv[i] = Reciprocal(D[i])
......@@ -465,14 +476,14 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG:
R[i][j] = A[i][j]
for k in range(i):
R[i][j] = MADS(False, R[i][j], M[k][i], R[k][j])
R[i][j] = MADS(False, False, R[i][j], M[k][i], R[k][j])
if is_complex:
M[i][j] = ComplexConjugate(R[i][j])
else:
M[i][j] = R[i][j]
# if is_complex:
# M[i][j] = ComplexConjugate(R[i][j])
# else:
M[i][j] = R[i][j]
R[i][j] = MADS(True, Constant(0, name="0"), R[i][j], D_inv[i])
R[i][j] = MADS(True, True, DontCare(), R[i][j], D_inv[i])
# back substitution
A_inv = [[None for _ in range(N)] for _ in range(N)]
......@@ -481,14 +492,16 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG:
for j in reversed(range(i + 1)):
for k in reversed(range(j + 1, N)):
if k == N - 1 and i != j:
A_inv[j][i] = MADS(
False, Constant(0, name="0"), R[j][k], A_inv[i][k]
)
A_inv[j][i] = MADS(False, True, DontCare(), R[j][k], A_inv[i][k])
else:
if A_inv[i][k]:
A_inv[j][i] = MADS(False, A_inv[j][i], R[j][k], A_inv[i][k])
A_inv[j][i] = MADS(
False, False, A_inv[j][i], R[j][k], A_inv[i][k]
)
else:
A_inv[j][i] = MADS(False, A_inv[j][i], R[j][k], A_inv[k][i])
A_inv[j][i] = MADS(
False, False, A_inv[j][i], R[j][k], A_inv[k][i]
)
outputs = []
for i in range(N):
......
......@@ -5,6 +5,7 @@ import pytest
from b_asic import (
MAD,
MADS,
SFG,
Absolute,
Addition,
AddSub,
......@@ -13,11 +14,13 @@ from b_asic import (
Constant,
ConstantMultiplication,
Division,
DontCare,
Input,
LeftShift,
Max,
Min,
Multiplication,
Output,
Reciprocal,
RightShift,
Shift,
......@@ -343,19 +346,33 @@ class TestMADS:
test_operation = MADS(is_add=True)
assert test_operation.evaluate_output(0, [3 + 6j, 2 + 6j, 1 + 1j]) == -1 + 14j
def test_mads_zero_override(self):
test_operation = MADS(is_add=True, override_zero_on_src0=True)
assert test_operation.evaluate_output(0, [1, 1, 1]) == 1
def test_mads_sub_zero_override(self):
test_operation = MADS(is_add=False, override_zero_on_src0=True)
assert test_operation.evaluate_output(0, [1, 1, 1]) == -1
def test_mads_is_linear(self):
test_operation = MADS(
Constant(3), Addition(Input(), Constant(3)), Addition(Input(), Constant(3))
src0=Constant(3),
src1=Addition(Input(), Constant(3)),
src2=Addition(Input(), Constant(3)),
)
assert not test_operation.is_linear
test_operation = MADS(
Addition(Input(), Constant(3)), Constant(3), Addition(Input(), Constant(3))
src0=Addition(Input(), Constant(3)),
src1=Constant(3),
src2=Addition(Input(), Constant(3)),
)
assert test_operation.is_linear
test_operation = MADS(
Addition(Input(), Constant(3)), Addition(Input(), Constant(3)), Constant(3)
src0=Addition(Input(), Constant(3)),
src1=Addition(Input(), Constant(3)),
src2=Constant(3),
)
assert test_operation.is_linear
......@@ -381,6 +398,22 @@ class TestMADS:
test_operation.is_add = False
assert not test_operation.is_add
def test_mads_override_zero_on_src0_getter(self):
test_operation = MADS(override_zero_on_src0=False)
assert not test_operation.override_zero_on_src0
test_operation = MADS(override_zero_on_src0=True)
assert test_operation.override_zero_on_src0
def test_mads_override_zero_on_src0_setter(self):
test_operation = MADS(override_zero_on_src0=False)
test_operation.override_zero_on_src0 = True
assert test_operation.override_zero_on_src0
test_operation = MADS(override_zero_on_src0=True)
test_operation.override_zero_on_src0 = False
assert not test_operation.override_zero_on_src0
class TestRightShift:
"""Tests for RightShift class."""
......@@ -556,6 +589,33 @@ class TestDepends:
assert set(bfly1.inputs_required_for_output(1)) == {0, 1}
class TestDontCare:
def test_create_sfg_with_dontcare(self):
i1 = Input()
dc = DontCare()
a = Addition(i1, dc)
o = Output(a)
sfg = SFG([i1], [o])
assert sfg.output_count == 1
assert sfg.input_count == 1
assert sfg.evaluate_output(0, [0]) == 0
assert sfg.evaluate_output(0, [1]) == 1
def test_dontcare_latency_getter(self):
test_operation = DontCare()
assert test_operation.latency == 0
def test_dontcare_repr(self):
test_operation = DontCare()
assert repr(test_operation) == "DontCare()"
def test_dontcare_str(self):
test_operation = DontCare()
assert str(test_operation) == "dontcare"
class TestSink:
def test_create_sfg_with_sink(self):
bfly = Butterfly()
......
......@@ -644,7 +644,7 @@ class TestRadix2FFT:
class TestLdltMatrixInverse:
def test_1x1(self):
sfg = ldlt_matrix_inverse(N=1, is_complex=False)
sfg = ldlt_matrix_inverse(N=1)
assert len(sfg.inputs) == 1
assert len(sfg.outputs) == 1
......@@ -661,7 +661,7 @@ class TestLdltMatrixInverse:
assert np.isclose(res["0"], 0.2)
def test_2x2_simple_spd(self):
sfg = ldlt_matrix_inverse(N=2, is_complex=False)
sfg = ldlt_matrix_inverse(N=2)
assert len(sfg.inputs) == 3
assert len(sfg.outputs) == 3
......@@ -683,7 +683,7 @@ class TestLdltMatrixInverse:
assert np.isclose(res["2"], A_inv[1, 1])
def test_3x3_simple_spd(self):
sfg = ldlt_matrix_inverse(N=3, is_complex=False)
sfg = ldlt_matrix_inverse(N=3)
assert len(sfg.inputs) == 6
assert len(sfg.outputs) == 6
......@@ -717,7 +717,7 @@ class TestLdltMatrixInverse:
def test_5x5_random_spd(self):
N = 5
sfg = ldlt_matrix_inverse(N=N, is_complex=False)
sfg = ldlt_matrix_inverse(N=N)
assert len(sfg.inputs) == 15
assert len(sfg.outputs) == 15
......@@ -746,7 +746,7 @@ class TestLdltMatrixInverse:
def test_20x20_random_spd(self):
N = 20
sfg = ldlt_matrix_inverse(N=N, is_complex=False)
sfg = ldlt_matrix_inverse(N=N)
A = self._generate_random_spd_matrix(N)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment