Code owners
Assign users and groups as approvers for specific file changes. Learn more.
test_core_operations.py 20.54 KiB
"""B-ASIC test suite for the core operations."""
import pytest
from b_asic import (
MAD,
MADS,
SFG,
Absolute,
Addition,
AddSub,
Butterfly,
ComplexConjugate,
Constant,
ConstantMultiplication,
Division,
DontCare,
Input,
LeftShift,
Max,
Min,
Multiplication,
Output,
Reciprocal,
RightShift,
Shift,
Sink,
SquareRoot,
Subtraction,
SymmetricTwoportAdaptor,
)
class TestConstant:
    """Exercise ``Constant``: evaluation, the ``value`` property and text forms."""

    def test_constant_positive(self):
        const = Constant(3)
        result = const.evaluate_output(0, [])
        assert result == 3
        assert const.value == 3

    def test_constant_negative(self):
        result = Constant(-3).evaluate_output(0, [])
        assert result == -3

    def test_constant_complex(self):
        expected = 3 + 4j
        assert Constant(expected).evaluate_output(0, []) == expected

    def test_constant_change_value(self):
        const = Constant(3)
        assert const.value == 3
        # The value is writable after construction.
        const.value = 4
        assert const.value == 4

    def test_constant_repr(self):
        assert repr(Constant(3)) == "Constant(3)"

    def test_constant_str(self):
        assert str(Constant(3)) == "3"
class TestAddition:
    """Check that ``Addition`` sums its two inputs for real and complex operands."""

    def test_addition_positive(self):
        adder = Addition()
        total = adder.evaluate_output(0, [3, 5])
        assert total == 8

    def test_addition_negative(self):
        adder = Addition()
        total = adder.evaluate_output(0, [-3, -5])
        assert total == -8

    def test_addition_complex(self):
        adder = Addition()
        lhs, rhs = 3 + 5j, 4 + 6j
        assert adder.evaluate_output(0, [lhs, rhs]) == 7 + 11j
class TestSubtraction:
    """Check that ``Subtraction`` computes first input minus second input."""

    def test_subtraction_positive(self):
        subtractor = Subtraction()
        difference = subtractor.evaluate_output(0, [5, 3])
        assert difference == 2

    def test_subtraction_negative(self):
        subtractor = Subtraction()
        difference = subtractor.evaluate_output(0, [-5, -3])
        assert difference == -2

    def test_subtraction_complex(self):
        subtractor = Subtraction()
        minuend, subtrahend = 3 + 5j, 4 + 6j
        assert subtractor.evaluate_output(0, [minuend, subtrahend]) == -1 - 1j
class TestAddSub:
    """Check ``AddSub`` in adding (``is_add=True``) and subtracting mode.

    Also covers ``is_swappable`` and the ``is_add`` getter and setter.
    """

    def test_addsub_positive(self):
        adder = AddSub(is_add=True)
        assert adder.evaluate_output(0, [3, 5]) == 8

    def test_addsub_negative(self):
        adder = AddSub(is_add=True)
        assert adder.evaluate_output(0, [-3, -5]) == -8
        assert adder.is_add

    def test_addsub_complex(self):
        adder = AddSub(is_add=True)
        assert adder.evaluate_output(0, [3 + 5j, 4 + 6j]) == 7 + 11j

    def test_addsub_subtraction_positive(self):
        subtractor = AddSub(is_add=False)
        assert subtractor.evaluate_output(0, [5, 3]) == 2
        assert not subtractor.is_add

    def test_addsub_subtraction_negative(self):
        subtractor = AddSub(is_add=False)
        assert subtractor.evaluate_output(0, [-5, -3]) == -2

    def test_addsub_subtraction_complex(self):
        subtractor = AddSub(is_add=False)
        assert subtractor.evaluate_output(0, [3 + 5j, 4 + 6j]) == -1 - 1j

    def test_addsub_subtraction_is_swappable(self):
        # Only the adding mode may have its inputs swapped.
        assert not AddSub(is_add=False).is_swappable
        assert AddSub(is_add=True).is_swappable

    def test_addsub_is_add_getter(self):
        for flag in (False, True):
            assert bool(AddSub(is_add=flag).is_add) == flag

    def test_addsub_is_add_setter(self):
        for initial in (False, True):
            addsub = AddSub(is_add=initial)
            # Flip the mode and confirm the property reflects the new value.
            addsub.is_add = not initial
            assert bool(addsub.is_add) == (not initial)
class TestMultiplication:
    """Check that ``Multiplication`` returns the product of its two inputs."""

    def test_multiplication_positive(self):
        product = Multiplication().evaluate_output(0, [5, 3])
        assert product == 15

    def test_multiplication_negative(self):
        # Two negative factors give a positive product.
        product = Multiplication().evaluate_output(0, [-5, -3])
        assert product == 15

    def test_multiplication_complex(self):
        # (3 + 5j) * (4 + 6j) = 12 + 18j + 20j - 30
        product = Multiplication().evaluate_output(0, [3 + 5j, 4 + 6j])
        assert product == -18 + 38j
class TestDivision:
    """Tests for Division class: evaluation and linearity detection."""

    def test_division_positive(self):
        test_operation = Division()
        assert test_operation.evaluate_output(0, [30, 5]) == 6

    def test_division_negative(self):
        test_operation = Division()
        assert test_operation.evaluate_output(0, [-30, -5]) == 6

    def test_division_complex(self):
        test_operation = Division()
        # (60 + 40j) / (10 + 20j) == (1400 - 800j) / 500
        assert test_operation.evaluate_output(0, [60 + 40j, 10 + 20j]) == 2.8 - 1.6j

    def test_division_is_linear(self):
        # A non-constant denominator makes the division non-linear...
        test_operation = Division(Constant(3), Addition(Input(), Constant(3)))
        assert not test_operation.is_linear
        # ...while dividing by a Constant is linear.
        test_operation = Division(Addition(Input(), Constant(3)), Constant(3))
        assert test_operation.is_linear
class TestSquareRoot:
    """Check ``SquareRoot`` on positive, negative and complex inputs."""

    def test_squareroot_positive(self):
        root = SquareRoot().evaluate_output(0, [36])
        assert root == 6

    def test_squareroot_negative(self):
        # A negative real input yields a purely imaginary root.
        root = SquareRoot().evaluate_output(0, [-36])
        assert root == 6j

    def test_squareroot_complex(self):
        # (8 + 4j) ** 2 == 48 + 64j
        root = SquareRoot().evaluate_output(0, [48 + 64j])
        assert root == 8 + 4j
class TestComplexConjugate:
    """Tests for ComplexConjugate class: the sign of the imaginary part flips."""

    def test_complexconjugate_positive(self):
        test_operation = ComplexConjugate()
        assert test_operation.evaluate_output(0, [3 + 4j]) == 3 - 4j

    def test_complexconjugate_negative(self):
        test_operation = ComplexConjugate()
        assert test_operation.evaluate_output(0, [-3 - 4j]) == -3 + 4j
class TestMax:
    """Tests for Max class."""

    def test_max_positive(self):
        test_operation = Max()
        assert test_operation.evaluate_output(0, [30, 5]) == 30

    def test_max_negative(self):
        test_operation = Max()
        assert test_operation.evaluate_output(0, [-30, -5]) == -5

    def test_max_complex(self):
        # Complex operands are rejected with a ValueError.
        test_operation = Max()
        with pytest.raises(
            ValueError, match="core_operations.Max does not support complex numbers."
        ):
            test_operation.evaluate_output(0, [-1 - 1j, 2 + 2j])


class TestMin:
    """Tests for Min class."""

    def test_min_positive(self):
        test_operation = Min()
        assert test_operation.evaluate_output(0, [30, 5]) == 5

    def test_min_negative(self):
        test_operation = Min()
        assert test_operation.evaluate_output(0, [-30, -5]) == -30

    def test_min_complex(self):
        # Complex operands are rejected with a ValueError.
        test_operation = Min()
        with pytest.raises(
            ValueError, match="core_operations.Min does not support complex numbers."
        ):
            test_operation.evaluate_output(0, [-1 - 1j, 2 + 2j])


class TestAbsolute:
    """Tests for Absolute class."""

    def test_absolute_positive(self):
        test_operation = Absolute()
        assert test_operation.evaluate_output(0, [30]) == 30

    def test_absolute_negative(self):
        test_operation = Absolute()
        assert test_operation.evaluate_output(0, [-5]) == 5

    def test_absolute_complex(self):
        # |3 + 4j| == 5 (the complex magnitude).
        test_operation = Absolute()
        assert test_operation.evaluate_output(0, [3 + 4j]) == 5.0
class TestConstantMultiplication:
    """Check that ``ConstantMultiplication`` scales its input by a fixed coefficient."""

    def test_constantmultiplication_positive(self):
        scaler = ConstantMultiplication(5)
        scaled = scaler.evaluate_output(0, [20])
        assert scaled == 100
        assert scaler.value == 5

    def test_constantmultiplication_negative(self):
        scaled = ConstantMultiplication(5).evaluate_output(0, [-5])
        assert scaled == -25

    def test_constantmultiplication_complex(self):
        # (3 + 2j) * (3 + 4j) = 9 + 12j + 6j - 8
        scaled = ConstantMultiplication(3 + 2j).evaluate_output(0, [3 + 4j])
        assert scaled == 1 + 18j
class TestMAD:
    """Cover ``MAD`` evaluation, linearity detection and I/O swapping.

    The expected values below correspond to ``src0 * src1 + src2``.
    """

    def test_mad_positive(self):
        mad = MAD()
        assert mad.evaluate_output(0, [1, 2, 3]) == 5  # 1*2 + 3

    def test_mad_negative(self):
        mad = MAD()
        assert mad.evaluate_output(0, [-3, -5, -8]) == 7  # 15 - 8

    def test_mad_complex(self):
        mad = MAD()
        # (3 + 6j) * (2 + 6j) == -30 + 30j, then + (1 + 1j)
        assert mad.evaluate_output(0, [3 + 6j, 2 + 6j, 1 + 1j]) == -29 + 31j

    def test_mad_is_linear(self):
        def varying():
            return Addition(Input(), Constant(3))

        # A Constant in either multiplicand slot makes the MAD linear.
        mad = MAD(Constant(3), varying(), varying())
        assert mad.is_linear
        mad = MAD(varying(), Constant(3), varying())
        assert mad.is_linear
        # A Constant only in the addend leaves both multiplicands varying.
        mad = MAD(varying(), varying(), Constant(3))
        assert not mad.is_linear

    def test_mad_swap_io(self):
        mad = MAD()
        assert mad.evaluate_output(0, [1, 2, 3]) == 5
        mad.swap_io()
        # Swapping I/O must not change the computed value.
        assert mad.evaluate_output(0, [1, 2, 3]) == 5
class TestMADS:
    """Cover ``MADS``: multiply-add/subtract evaluation, zero override, linearity,
    I/O swapping and the ``is_add`` / ``override_zero_on_src0`` properties.

    The expected values below correspond to ``src0 + src1 * src2`` when
    ``is_add`` is True and ``src0 - src1 * src2`` otherwise.
    """

    def test_mads_positive(self):
        mads = MADS(is_add=False)
        assert mads.evaluate_output(0, [1, 2, 3]) == -5  # 1 - 2*3

    def test_mads_negative(self):
        mads = MADS(is_add=False)
        assert mads.evaluate_output(0, [-3, -5, -8]) == -43  # -3 - 40

    def test_mads_complex(self):
        mads = MADS(is_add=False)
        # (2 + 6j) * (1 + 1j) == -4 + 8j
        assert mads.evaluate_output(0, [3 + 6j, 2 + 6j, 1 + 1j]) == 7 - 2j

    def test_mads_positive_add(self):
        mads = MADS(is_add=True)
        assert mads.evaluate_output(0, [1, 2, 3]) == 7  # 1 + 2*3

    def test_mads_negative_add(self):
        mads = MADS(is_add=True)
        assert mads.evaluate_output(0, [-3, -5, -8]) == 37  # -3 + 40

    def test_mads_complex_add(self):
        mads = MADS(is_add=True)
        assert mads.evaluate_output(0, [3 + 6j, 2 + 6j, 1 + 1j]) == -1 + 14j

    def test_mads_zero_override(self):
        # With the override set, src0 contributes 0: 0 + 1*1.
        mads = MADS(is_add=True, override_zero_on_src0=True)
        assert mads.evaluate_output(0, [1, 1, 1]) == 1

    def test_mads_sub_zero_override(self):
        # With the override set, src0 contributes 0: 0 - 1*1.
        mads = MADS(is_add=False, override_zero_on_src0=True)
        assert mads.evaluate_output(0, [1, 1, 1]) == -1

    def test_mads_is_linear(self):
        def varying():
            return Addition(Input(), Constant(3))

        # Constant in src0: both multiplicands vary, so the result is non-linear.
        mads = MADS(src0=Constant(3), src1=varying(), src2=varying())
        assert not mads.is_linear
        # Constant in src1 or src2: one multiplicand is fixed, so it is linear.
        mads = MADS(src0=varying(), src1=Constant(3), src2=varying())
        assert mads.is_linear
        mads = MADS(src0=varying(), src1=varying(), src2=Constant(3))
        assert mads.is_linear

    def test_mads_swap_io(self):
        mads = MADS(is_add=False)
        assert mads.evaluate_output(0, [1, 2, 3]) == -5
        mads.swap_io()
        # Swapping I/O must not change the computed value.
        assert mads.evaluate_output(0, [1, 2, 3]) == -5

    def test_mads_is_add_getter(self):
        for flag in (False, True):
            assert bool(MADS(is_add=flag).is_add) == flag

    def test_mads_is_add_setter(self):
        for initial in (False, True):
            mads = MADS(is_add=initial)
            mads.is_add = not initial
            assert bool(mads.is_add) == (not initial)

    def test_mads_override_zero_on_src0_getter(self):
        for flag in (False, True):
            mads = MADS(override_zero_on_src0=flag)
            assert bool(mads.override_zero_on_src0) == flag

    def test_mads_override_zero_on_src0_setter(self):
        for initial in (False, True):
            mads = MADS(override_zero_on_src0=initial)
            mads.override_zero_on_src0 = not initial
            assert bool(mads.override_zero_on_src0) == (not initial)
class TestRightShift:
    """Check ``RightShift`` (division by ``2 ** value``) and its value validation."""

    def test_rightshift_positive(self):
        shifter = RightShift(2)
        assert shifter.evaluate_output(0, [20]) == 5
        assert shifter.value == 2

    def test_rightshift_negative(self):
        shifted = RightShift(2).evaluate_output(0, [-5])
        assert shifted == -1.25

    def test_rightshift_complex(self):
        shifted = RightShift(2).evaluate_output(0, [2 + 1j])
        assert shifted == 0.5 + 0.25j

    def test_rightshift_errors(self):
        # Non-integer amounts are rejected at construction and on assignment.
        with pytest.raises(TypeError, match="value must be an int"):
            _ = RightShift(0.5)
        shifter = RightShift(0)
        with pytest.raises(TypeError, match="value must be an int"):
            shifter.value = 0.5
        # Negative amounts are rejected the same way.
        with pytest.raises(ValueError, match="value must be non-negative"):
            _ = RightShift(-1)
        shifter = RightShift(0)
        with pytest.raises(ValueError, match="value must be non-negative"):
            shifter.value = -1
class TestLeftShift:
    """Check ``LeftShift`` (multiplication by ``2 ** value``) and its value validation."""

    def test_leftshift_positive(self):
        shifter = LeftShift(2)
        assert shifter.evaluate_output(0, [5]) == 20
        assert shifter.value == 2

    def test_leftshift_negative(self):
        shifted = LeftShift(2).evaluate_output(0, [-5])
        assert shifted == -20

    def test_leftshift_complex(self):
        shifted = LeftShift(2).evaluate_output(0, [0.5 + 0.25j])
        assert shifted == 2 + 1j

    def test_leftshift_errors(self):
        # Non-integer amounts are rejected at construction and on assignment.
        with pytest.raises(TypeError, match="value must be an int"):
            _ = LeftShift(0.5)
        shifter = LeftShift(0)
        with pytest.raises(TypeError, match="value must be an int"):
            shifter.value = 0.5
        # Negative amounts are rejected the same way.
        with pytest.raises(ValueError, match="value must be non-negative"):
            _ = LeftShift(-1)
        shifter = LeftShift(0)
        with pytest.raises(ValueError, match="value must be non-negative"):
            shifter.value = -1
class TestShift:
    """Tests for Shift class.

    A positive value shifts left (multiplies by ``2 ** value``) and a negative
    value shifts right.
    """

    def test_shift_positive(self):
        test_operation = Shift(2)
        assert test_operation.evaluate_output(0, [5]) == 20
        assert test_operation.value == 2
        test_operation = Shift(-2)
        assert test_operation.evaluate_output(0, [5]) == 1.25
        assert test_operation.value == -2

    def test_shift_negative(self):
        test_operation = Shift(2)
        assert test_operation.evaluate_output(0, [-5]) == -20
        test_operation = Shift(-2)
        assert test_operation.evaluate_output(0, [-5]) == -1.25

    def test_shift_complex(self):
        test_operation = Shift(2)
        assert test_operation.evaluate_output(0, [0.5 + 0.25j]) == 2 + 1j
        test_operation = Shift(-2)
        assert test_operation.evaluate_output(0, [2 + 1j]) == 0.5 + 0.25j

    @pytest.mark.parametrize("val", (-0.5, 0.5))
    def test_shift_errors(self, val):
        # Negative amounts are allowed, but the amount must still be an int,
        # both at construction and on assignment.
        with pytest.raises(TypeError, match="value must be an int"):
            _ = Shift(val)
        test_operation = Shift(0)
        with pytest.raises(TypeError, match="value must be an int"):
            test_operation.value = val
class TestButterfly:
    """Check ``Butterfly``: output 0 is the sum and output 1 the difference of the inputs."""

    def test_butterfly_positive(self):
        bfly = Butterfly()
        for port, expected in enumerate((5, -1)):
            assert bfly.evaluate_output(port, [2, 3]) == expected

    def test_butterfly_negative(self):
        bfly = Butterfly()
        for port, expected in enumerate((-5, 1)):
            assert bfly.evaluate_output(port, [-2, -3]) == expected

    def test_butterfly_complex(self):
        bfly = Butterfly()
        for port, expected in enumerate((5 - 1j, -1 + 3j)):
            assert bfly.evaluate_output(port, [2 + 1j, 3 - 2j]) == expected
class TestSymmetricTwoportAdaptor:
    """Check ``SymmetricTwoportAdaptor`` outputs, I/O swapping and value range checks."""

    def test_symmetrictwoportadaptor_positive(self):
        adaptor = SymmetricTwoportAdaptor(0.5)
        for port, expected in enumerate((3.5, 2.5)):
            assert adaptor.evaluate_output(port, [2, 3]) == expected
        assert adaptor.value == 0.5

    def test_symmetrictwoportadaptor_negative(self):
        adaptor = SymmetricTwoportAdaptor(0.5)
        for port, expected in enumerate((-3.5, -2.5)):
            assert adaptor.evaluate_output(port, [-2, -3]) == expected

    def test_symmetrictwoportadaptor_complex(self):
        adaptor = SymmetricTwoportAdaptor(0.5)
        for port, expected in enumerate((3.5 - 3.5j, 2.5 - 0.5j)):
            assert adaptor.evaluate_output(port, [2 + 1j, 3 - 2j]) == expected

    def test_symmetrictwoportadaptor_swap_io(self):
        adaptor = SymmetricTwoportAdaptor(0.5)
        assert adaptor.value == 0.5
        # Swapping I/O negates the adaptor coefficient.
        adaptor.swap_io()
        assert adaptor.value == -0.5

    def test_symmetrictwoportadaptor_error(self):
        # Out-of-range values are rejected at construction and on assignment.
        with pytest.raises(ValueError, match="value must be between -1 and 1"):
            _ = SymmetricTwoportAdaptor(-2)
        adaptor = SymmetricTwoportAdaptor(0)
        with pytest.raises(ValueError, match="value must be between -1 and 1"):
            adaptor.value = 2
class TestReciprocal:
    """Tests for Reciprocal class: the output is ``1 / input``."""

    def test_reciprocal_positive(self):
        test_operation = Reciprocal()
        assert test_operation.evaluate_output(0, [2]) == 0.5

    def test_reciprocal_negative(self):
        test_operation = Reciprocal()
        assert test_operation.evaluate_output(0, [-5]) == -0.2

    def test_reciprocal_complex(self):
        # 1 / (1 + 1j) == (1 - 1j) / 2
        test_operation = Reciprocal()
        assert test_operation.evaluate_output(0, [1 + 1j]) == 0.5 - 0.5j
class TestDepends:
    """Check which input ports each output of an operation depends on."""

    def test_depends_addition(self):
        required = set(Addition().inputs_required_for_output(0))
        assert required == {0, 1}

    def test_depends_butterfly(self):
        bfly = Butterfly()
        # Both butterfly outputs depend on both inputs.
        for output_index in (0, 1):
            assert set(bfly.inputs_required_for_output(output_index)) == {0, 1}
class TestDontCare:
    """Cover ``DontCare`` inside an SFG, plus its latency and text forms."""

    def test_create_sfg_with_dontcare(self):
        inp = Input()
        dont_care = DontCare()
        out = Output(Addition(inp, dont_care))
        sfg = SFG([inp], [out])
        assert sfg.output_count == 1
        assert sfg.input_count == 1
        # The sum passes the input through unchanged, i.e. the DontCare
        # operand evaluates as 0 here.
        for value in (0, 1):
            assert sfg.evaluate_output(0, [value]) == value

    def test_dontcare_latency_getter(self):
        assert DontCare().latency == 0

    def test_dontcare_repr(self):
        assert repr(DontCare()) == "DontCare()"

    def test_dontcare_str(self):
        assert str(DontCare()) == "dontcare"
class TestSink:
    """Cover ``Sink`` as a replacement for an SFG output, plus its latency and text forms."""

    def test_create_sfg_with_sink(self):
        butterfly = Butterfly()
        sfg = butterfly.to_sfg()
        sink = Sink()
        # Replacing "out0" with a Sink removes that output port and warns about it.
        with pytest.warns(UserWarning, match="Output port out0 has been removed"):
            trimmed = sfg.replace_operation(sink, "out0")
        assert trimmed.output_count == 1
        assert trimmed.input_count == 2
        # The trimmed SFG's only output matches the original second output.
        assert sfg.evaluate_output(1, [0, 1]) == trimmed.evaluate_output(0, [0, 1])

    def test_sink_latency_getter(self):
        assert Sink().latency == 0

    def test_sink_repr(self):
        assert repr(Sink()) == "Sink()"

    def test_sink_str(self):
        assert str(Sink()) == "sink"