diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 67d6b07f5d0344224d6e08f760c4cff82d328a09..fd75359d32afe8eb9914dae699549d980b08d536 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -1170,9 +1170,9 @@ class MADS(AbstractOperation): def swap_io(self) -> None: self._input_ports = [ - self._input_ports[1], self._input_ports[0], self._input_ports[2], + self._input_ports[1], ] for i, p in enumerate(self._input_ports): p._index = i diff --git a/b_asic/sfg_generators.py b/b_asic/sfg_generators.py index a0847de4943dca7d9f6c7515df7983971f52d622..0f7d26c50c5354553713d1655fc3d0dafc5ea70d 100644 --- a/b_asic/sfg_generators.py +++ b/b_asic/sfg_generators.py @@ -457,7 +457,7 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG: # R*di*R^T factorization for i in range(N): for k in range(i): - D[i] = MADS(False, D[i], M[k][i], R[k][i], name="M1") + D[i] = MADS(False, D[i], M[k][i], R[k][i]) D_inv[i] = Reciprocal(D[i]) @@ -465,14 +465,14 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG: R[i][j] = A[i][j] for k in range(i): - R[i][j] = MADS(False, R[i][j], M[k][i], R[k][j], name="M2") + R[i][j] = MADS(False, R[i][j], M[k][i], R[k][j]) if is_complex: M[i][j] = ComplexConjugate(R[i][j]) else: M[i][j] = R[i][j] - R[i][j] = MADS(True, Constant(0, name="0"), R[i][j], D_inv[i], name="M3") + R[i][j] = MADS(True, Constant(0, name="0"), R[i][j], D_inv[i]) # back substitution A_inv = [[None for _ in range(N)] for _ in range(N)] @@ -482,13 +482,11 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG: for k in reversed(range(j + 1, N)): if k == N - 1 and i != j: A_inv[j][i] = MADS( - False, Constant(0, name="0"), R[j][k], A_inv[i][k], name="M4" + False, Constant(0, name="0"), R[j][k], A_inv[i][k] ) else: if A_inv[i][k]: - A_inv[j][i] = MADS( - False, A_inv[j][i], R[j][k], A_inv[i][k], name="M5" - ) + A_inv[j][i] = MADS(False, A_inv[j][i], R[j][k], A_inv[i][k]) else: A_inv[j][i] = MADS(False, A_inv[j][i], R[j][k], A_inv[k][i]) @@ -533,26 +531,6 @@ def _construct_dif_fft_stage( return ports -def _extract_diagonal(elems): - n = 0 - k = 0 - while k <= len(elems): - k += n + 1 - n += 1 - n -= 1 - k -= n + 1 - if k != len(elems): - return None - - diagonal = np.zeros(n) - index = 0 - for i in range(n): - diagonal[n] = elems[index] - index += i + 2 - - return diagonal - - def _get_bit_reversed_number(number: int, number_of_bits: int) -> int: reversed_number = 0 for i in range(number_of_bits): diff --git a/test/test_core_operations.py b/test/test_core_operations.py index 1a5f367b07bfb9d9a3bf6d3d70f7b1a53901db90..328acad27a38b16f27c6a424d0e7af45f42be1d9 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -3,6 +3,8 @@ import pytest from b_asic import ( + MAD, + MADS, Absolute, Addition, AddSub, @@ -11,6 +13,7 @@ from b_asic import ( Constant, ConstantMultiplication, Division, + Input, LeftShift, Max, Min, @@ -47,6 +50,14 @@ class TestConstant: test_operation.value = 4 assert test_operation.value == 4 + def test_constant_repr(self): + test_operation = Constant(3) + assert repr(test_operation) == "Constant(3)" + + def test_constant_str(self): + test_operation = Constant(3) + assert str(test_operation) == "3" + class TestAddition: """Tests for Addition class.""" @@ -83,16 +94,16 @@ class TestSubtraction: class TestAddSub: """Tests for AddSub class.""" - def test_addition_positive(self): + def test_addsub_positive(self): test_operation = AddSub(is_add=True) assert test_operation.evaluate_output(0, [3, 5]) == 8 - def test_addition_negative(self): + def test_addsub_negative(self): test_operation = AddSub(is_add=True) assert test_operation.evaluate_output(0, [-3, -5]) == -8 assert test_operation.is_add - def test_addition_complex(self): + def test_addsub_complex(self): test_operation = AddSub(is_add=True) assert test_operation.evaluate_output(0, [3 + 5j, 4 + 6j]) == 7 + 11j @@ -116,6 +127,22 @@ class TestAddSub: test_operation = AddSub(is_add=True) assert test_operation.is_swappable + def test_addsub_is_add_getter(self): + test_operation = AddSub(is_add=False) + assert not test_operation.is_add + + test_operation = AddSub(is_add=True) + assert test_operation.is_add + + def test_addsub_is_add_setter(self): + test_operation = AddSub(is_add=False) + test_operation.is_add = True + assert test_operation.is_add + + test_operation = AddSub(is_add=True) + test_operation.is_add = False + assert not test_operation.is_add + class TestMultiplication: """Tests for Multiplication class.""" @@ -148,6 +175,13 @@ class TestDivision: test_operation = Division() assert test_operation.evaluate_output(0, [60 + 40j, 10 + 20j]) == 2.8 - 1.6j + def test_mads_is_linear(self): + test_operation = Division(Constant(3), Addition(Input(), Constant(3))) + assert not test_operation.is_linear + + test_operation = Division(Addition(Input(), Constant(3)), Constant(3)) + assert test_operation.is_linear + class TestSquareRoot: """Tests for SquareRoot class.""" @@ -200,6 +234,13 @@ class TestMin: test_operation = Min() assert test_operation.evaluate_output(0, [-30, -5]) == -30 + def test_min_complex(self): + test_operation = Min() + with pytest.raises( + ValueError, match="core_operations.Min does not support complex numbers." + ): + test_operation.evaluate_output(0, [-1 - 1j, 2 + 2j]) + class TestAbsolute: """Tests for Absolute class.""" @@ -216,6 +257,13 @@ class TestAbsolute: test_operation = Absolute() assert test_operation.evaluate_output(0, [3 + 4j]) == 5.0 + def test_max_complex(self): + test_operation = Max() + with pytest.raises( + ValueError, match="core_operations.Max does not support complex numbers." + ): + test_operation.evaluate_output(0, [-1 - 1j, 2 + 2j]) + class TestConstantMultiplication: """Tests for ConstantMultiplication class.""" @@ -234,6 +282,106 @@ class TestConstantMultiplication: assert test_operation.evaluate_output(0, [3 + 4j]) == 1 + 18j +class TestMAD: + def test_mad_positive(self): + test_operation = MAD() + assert test_operation.evaluate_output(0, [1, 2, 3]) == 5 + + def test_mad_negative(self): + test_operation = MAD() + assert test_operation.evaluate_output(0, [-3, -5, -8]) == 7 + + def test_mad_complex(self): + test_operation = MAD() + assert test_operation.evaluate_output(0, [3 + 6j, 2 + 6j, 1 + 1j]) == -29 + 31j + + def test_mad_is_linear(self): + test_operation = MAD( + Constant(3), Addition(Input(), Constant(3)), Addition(Input(), Constant(3)) + ) + assert test_operation.is_linear + + test_operation = MAD( + Addition(Input(), Constant(3)), Constant(3), Addition(Input(), Constant(3)) + ) + assert test_operation.is_linear + + test_operation = MAD( + Addition(Input(), Constant(3)), Addition(Input(), Constant(3)), Constant(3) + ) + assert not test_operation.is_linear + + def test_mad_swap_io(self): + test_operation = MAD() + assert test_operation.evaluate_output(0, [1, 2, 3]) == 5 + test_operation.swap_io() + assert test_operation.evaluate_output(0, [1, 2, 3]) == 5 + + +class TestMADS: + def test_mads_positive(self): + test_operation = MADS(is_add=False) + assert test_operation.evaluate_output(0, [1, 2, 3]) == -5 + + def test_mads_negative(self): + test_operation = MADS(is_add=False) + assert test_operation.evaluate_output(0, [-3, -5, -8]) == -43 + + def test_mads_complex(self): + test_operation = MADS(is_add=False) + assert test_operation.evaluate_output(0, [3 + 6j, 2 + 6j, 1 + 1j]) == 7 - 2j + + def test_mads_positive_add(self): + test_operation = MADS(is_add=True) + assert test_operation.evaluate_output(0, [1, 2, 3]) == 7 + + def test_mads_negative_add(self): + test_operation = MADS(is_add=True) + assert test_operation.evaluate_output(0, [-3, -5, -8]) == 37 + + def test_mads_complex_add(self): + test_operation = MADS(is_add=True) + assert test_operation.evaluate_output(0, [3 + 6j, 2 + 6j, 1 + 1j]) == -1 + 14j + + def test_mads_is_linear(self): + test_operation = MADS( + Constant(3), Addition(Input(), Constant(3)), Addition(Input(), Constant(3)) + ) + assert not test_operation.is_linear + + test_operation = MADS( + Addition(Input(), Constant(3)), Constant(3), Addition(Input(), Constant(3)) + ) + assert test_operation.is_linear + + test_operation = MADS( + Addition(Input(), Constant(3)), Addition(Input(), Constant(3)), Constant(3) + ) + assert test_operation.is_linear + + def test_mads_swap_io(self): + test_operation = MADS(is_add=False) + assert test_operation.evaluate_output(0, [1, 2, 3]) == -5 + test_operation.swap_io() + assert test_operation.evaluate_output(0, [1, 2, 3]) == -5 + + def test_mads_is_add_getter(self): + test_operation = MADS(is_add=False) + assert not test_operation.is_add + + test_operation = MADS(is_add=True) + assert test_operation.is_add + + def test_mads_is_add_setter(self): + test_operation = MADS(is_add=False) + test_operation.is_add = True + assert test_operation.is_add + + test_operation = MADS(is_add=True) + test_operation.is_add = False + assert not test_operation.is_add + + class TestRightShift: """Tests for RightShift class.""" @@ -420,3 +568,15 @@ class TestSink: assert sfg1.input_count == 2 assert sfg.evaluate_output(1, [0, 1]) == sfg1.evaluate_output(0, [0, 1]) + + def test_sink_latency_getter(self): + test_operation = Sink() + assert test_operation.latency == 0 + + def test_sink_repr(self): + test_operation = Sink() + assert repr(test_operation) == "Sink()" + + def test_sink_str(self): + test_operation = Sink() + assert str(test_operation) == "sink" diff --git a/test/test_sfg_generators.py b/test/test_sfg_generators.py index e3430bc40c9e79636f40e1e077c092c4b9944c04..3d876d7595c92132fe8f247c50719da9a5b385e7 100644 --- a/test/test_sfg_generators.py +++ b/test/test_sfg_generators.py @@ -807,9 +807,9 @@ class TestLdltMatrixInverse: A += (np.abs(min_eig) + 0.1) * np.eye(N) # ensure positive definiteness return A - def _generate_random_complex_spd_matrix(self, N: int) -> np.ndarray: - A = np.random.randn(N, N) + 1j * np.random.randn(N, N) - A = (A + A.conj().T) / 2 # ensure symmetric - min_eig = np.min(np.linalg.eigvals(A)) - A += (np.abs(min_eig) + 0.1) * np.eye(N) # ensure positive definiteness - return A + # def _generate_random_complex_spd_matrix(self, N: int) -> np.ndarray: + # A = np.random.randn(N, N) + 1j * np.random.randn(N, N) + # A = (A + A.conj().T) / 2 # ensure symmetric + # min_eig = np.min(np.linalg.eigvals(A)) + # A += (np.abs(min_eig) + 0.1) * np.eye(N) # ensure positive definiteness + # return A