diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 00f07a9592ea7a8242215c1859887d8600f3dbd8..090a80971fbe0fa7b6f339702f1251ca27c5bca3 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -783,7 +783,7 @@ class SymmetricTwoportAdaptor(AbstractOperation): latency_offsets=latency_offsets, execution_time=execution_time, ) - self.set_param("value", value) + self.value = value @classmethod def type_name(cls) -> TypeName: @@ -801,7 +801,10 @@ class SymmetricTwoportAdaptor(AbstractOperation): @value.setter def value(self, value: Num) -> None: """Set the constant value of this operation.""" - self.set_param("value", value) + if -1 <= value <= 1: + self.set_param("value", value) + else: + raise ValueError('value must be between -1 and 1 (inclusive)') def swap_io(self) -> None: # Swap inputs and outputs and change sign of coefficient @@ -852,3 +855,233 @@ class Reciprocal(AbstractOperation): def evaluate(self, a): return 1 / a + + +class RightShift(AbstractOperation): + r""" + Arithmetic right-shift operation. + + Shifts the input to the right assuming a fixed-point representation, so + a multiplication by a power of two. + + .. math:: y = x \gg \text{value} = 2^{-\text{value}}x \text{ where value} \geq 0 + + Parameters + ---------- + value : int + Number of bits to shift right. + src0 : :class:`~b_asic.port.SignalSourceProvider`, optional + The signal to shift right. + name : Name, optional + Operation name. + latency : int, optional + Operation latency (delay from input to output in time units). + latency_offsets : dict[str, int], optional + Used if input arrives later than when the operator starts, e.g., + ``{"in0": 0`` which corresponds to *src0* arriving one time unit after the + operator starts. If not provided and *latency* is provided, set to zero. + execution_time : int, optional + Operation execution time (time units before operator can be reused). + + See Also + -------- + LeftShift + Shift + """ + + is_linear = True + + def __init__( + self, + value: int = 0, + src0: Optional[SignalSourceProvider] = None, + name: Name = Name(""), + latency: Optional[int] = None, + latency_offsets: Optional[Dict[str, int]] = None, + execution_time: Optional[int] = None, + ): + """Construct a RightShift operation with the given value.""" + super().__init__( + input_count=1, + output_count=1, + name=Name(name), + input_sources=[src0], + latency=latency, + latency_offsets=latency_offsets, + execution_time=execution_time, + ) + self.value = value + + @classmethod + def type_name(cls) -> TypeName: + return TypeName("rshift") + + def evaluate(self, a): + return a * 2 ** (-self.param("value")) + + @property + def value(self) -> int: + """Get the constant value of this operation.""" + return self.param("value") + + @value.setter + def value(self, value: int) -> None: + """Set the constant value of this operation.""" + if not isinstance(value, int): + raise TypeError("value must be an int") + if value < 0: + raise ValueError("value must be non-negative") + self.set_param("value", value) + + +class LeftShift(AbstractOperation): + r""" + Arithmetic left-shift operation. + + Shifts the input to the left assuming a fixed-point representation, so + a multiplication by a power of two. + + .. math:: y = x \ll \text{value} = 2^{\text{value}}x \text{ where value} \geq 0 + + Parameters + ---------- + value : int + Number of bits to shift left. + src0 : :class:`~b_asic.port.SignalSourceProvider`, optional + The signal to shift left. + name : Name, optional + Operation name. + latency : int, optional + Operation latency (delay from input to output in time units). + latency_offsets : dict[str, int], optional + Used if input arrives later than when the operator starts, e.g., + ``{"in0": 0`` which corresponds to *src0* arriving one time unit after the + operator starts. If not provided and *latency* is provided, set to zero. + execution_time : int, optional + Operation execution time (time units before operator can be reused). + + See Also + -------- + RightShift + Shift + """ + + is_linear = True + + def __init__( + self, + value: int = 0, + src0: Optional[SignalSourceProvider] = None, + name: Name = Name(""), + latency: Optional[int] = None, + latency_offsets: Optional[Dict[str, int]] = None, + execution_time: Optional[int] = None, + ): + """Construct a RightShift operation with the given value.""" + super().__init__( + input_count=1, + output_count=1, + name=Name(name), + input_sources=[src0], + latency=latency, + latency_offsets=latency_offsets, + execution_time=execution_time, + ) + self.value = value + + @classmethod + def type_name(cls) -> TypeName: + return TypeName("lshift") + + def evaluate(self, a): + return a * 2 ** (self.param("value")) + + @property + def value(self) -> int: + """Get the constant value of this operation.""" + return self.param("value") + + @value.setter + def value(self, value: int) -> None: + """Set the constant value of this operation.""" + if not isinstance(value, int): + raise TypeError("value must be an int") + if value < 0: + raise ValueError("value must be non-negative") + self.set_param("value", value) + + +class Shift(AbstractOperation): + r""" + Arithmetic shift operation. + + Shifts the input to the left or right assuming a fixed-point representation, so + a multiplication by a power of two. By definition a positive value is a shift to + the left. + + .. math:: y = x \ll \text{value} = 2^{\text{value}}x + + Parameters + ---------- + value : int + Number of bits to shift. Positive *value* shifts to the left. + src0 : :class:`~b_asic.port.SignalSourceProvider`, optional + The signal to shift. + name : Name, optional + Operation name. + latency : int, optional + Operation latency (delay from input to output in time units). + latency_offsets : dict[str, int], optional + Used if input arrives later than when the operator starts, e.g., + ``{"in0": 0`` which corresponds to *src0* arriving one time unit after the + operator starts. If not provided and *latency* is provided, set to zero. + execution_time : int, optional + Operation execution time (time units before operator can be reused). + + See Also + -------- + LeftShift + RightShift + """ + + is_linear = True + + def __init__( + self, + value: int = 0, + src0: Optional[SignalSourceProvider] = None, + name: Name = Name(""), + latency: Optional[int] = None, + latency_offsets: Optional[Dict[str, int]] = None, + execution_time: Optional[int] = None, + ): + """Construct a Shift operation with the given value.""" + super().__init__( + input_count=1, + output_count=1, + name=Name(name), + input_sources=[src0], + latency=latency, + latency_offsets=latency_offsets, + execution_time=execution_time, + ) + self.value = value + + @classmethod + def type_name(cls) -> TypeName: + return TypeName("shift") + + def evaluate(self, a): + return a * 2 ** (self.param("value")) + + @property + def value(self) -> int: + """Get the constant value of this operation.""" + return self.param("value") + + @value.setter + def value(self, value: int) -> None: + """Set the constant value of this operation.""" + if not isinstance(value, int): + raise TypeError("value must be an int") + self.set_param("value", value) diff --git a/test/test_core_operations.py b/test/test_core_operations.py index b53d588ee38f553d0181d7eb3b3c209769e01a80..40f15f2fa04b5750a5b18899ffc7d76cb17a8ae1 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -1,4 +1,5 @@ """B-ASIC test suite for the core operations.""" +import pytest from b_asic import ( Absolute, @@ -9,10 +10,13 @@ from b_asic import ( Constant, ConstantMultiplication, Division, + LeftShift, Max, Min, Multiplication, Reciprocal, + RightShift, + Shift, SquareRoot, Subtraction, SymmetricTwoportAdaptor, @@ -25,6 +29,7 @@ class TestConstant: def test_constant_positive(self): test_operation = Constant(3) assert test_operation.evaluate_output(0, []) == 3 + assert test_operation.value == 3 def test_constant_negative(self): test_operation = Constant(-3) @@ -83,6 +88,7 @@ class TestAddSub: def test_addition_negative(self): test_operation = AddSub(is_add=True) assert test_operation.evaluate_output(0, [-3, -5]) == -8 + assert test_operation.is_add def test_addition_complex(self): test_operation = AddSub(is_add=True) @@ -91,6 +97,7 @@ class TestAddSub: def test_addsub_subtraction_positive(self): test_operation = AddSub(is_add=False) assert test_operation.evaluate_output(0, [5, 3]) == 2 + assert not test_operation.is_add def test_addsub_subtraction_negative(self): test_operation = AddSub(is_add=False) @@ -214,6 +221,7 @@ class TestConstantMultiplication: def test_constantmultiplication_positive(self): test_operation = ConstantMultiplication(5) assert test_operation.evaluate_output(0, [20]) == 100 + assert test_operation.value == 5 def test_constantmultiplication_negative(self): test_operation = ConstantMultiplication(5) @@ -224,6 +232,101 @@ class TestConstantMultiplication: assert test_operation.evaluate_output(0, [3 + 4j]) == 1 + 18j +class TestRightShift: + """Tests for RightShift class.""" + + def test_rightshift_positive(self): + test_operation = RightShift(2) + assert test_operation.evaluate_output(0, [20]) == 5 + assert test_operation.value == 2 + + def test_rightshift_negative(self): + test_operation = RightShift(2) + assert test_operation.evaluate_output(0, [-5]) == -1.25 + + def test_rightshift_complex(self): + test_operation = RightShift(2) + assert test_operation.evaluate_output(0, [2 + 1j]) == 0.5 + 0.25j + + def test_rightshift_errors(self): + with pytest.raises(TypeError, match="value must be an int"): + _ = RightShift(0.5) + test_operation = RightShift(0) + with pytest.raises(TypeError, match="value must be an int"): + test_operation.value = 0.5 + + with pytest.raises(ValueError, match="value must be non-negative"): + _ = RightShift(-1) + test_operation = RightShift(0) + with pytest.raises(ValueError, match="value must be non-negative"): + test_operation.value = -1 + + +class TestLeftShift: + """Tests for LeftShift class.""" + + def test_leftshift_positive(self): + test_operation = LeftShift(2) + assert test_operation.evaluate_output(0, [5]) == 20 + assert test_operation.value == 2 + + def test_leftshift_negative(self): + test_operation = LeftShift(2) + assert test_operation.evaluate_output(0, [-5]) == -20 + + def test_leftshift_complex(self): + test_operation = LeftShift(2) + assert test_operation.evaluate_output(0, [0.5 + 0.25j]) == 2 + 1j + + def test_leftshift_errors(self): + with pytest.raises(TypeError, match="value must be an int"): + _ = LeftShift(0.5) + test_operation = LeftShift(0) + with pytest.raises(TypeError, match="value must be an int"): + test_operation.value = 0.5 + + with pytest.raises(ValueError, match="value must be non-negative"): + _ = LeftShift(-1) + test_operation = LeftShift(0) + with pytest.raises(ValueError, match="value must be non-negative"): + test_operation.value = -1 + + +class TestShift: + """Tests for Shift class.""" + + def test_shift_positive(self): + test_operation = Shift(2) + assert test_operation.evaluate_output(0, [5]) == 20 + assert test_operation.value == 2 + + test_operation = Shift(-2) + assert test_operation.evaluate_output(0, [5]) == 1.25 + assert test_operation.value == -2 + + def test_shift_negative(self): + test_operation = Shift(2) + assert test_operation.evaluate_output(0, [-5]) == -20 + + test_operation = Shift(-2) + assert test_operation.evaluate_output(0, [-5]) == -1.25 + + def test_shift_complex(self): + test_operation = Shift(2) + assert test_operation.evaluate_output(0, [0.5 + 0.25j]) == 2 + 1j + + test_operation = Shift(-2) + assert test_operation.evaluate_output(0, [2 + 1j]) == 0.5 + 0.25j + + @pytest.mark.parametrize("val", (-0.5, 0.5)) + def test_leftshift_errors(self, val): + with pytest.raises(TypeError, match="value must be an int"): + _ = Shift(val) + test_operation = Shift(0) + with pytest.raises(TypeError, match="value must be an int"): + test_operation.value = val + + class TestButterfly: """Tests for Butterfly class.""" @@ -237,7 +340,7 @@ class TestButterfly: assert test_operation.evaluate_output(0, [-2, -3]) == -5 assert test_operation.evaluate_output(1, [-2, -3]) == 1 - def test_buttefly_complex(self): + def test_butterfly_complex(self): test_operation = Butterfly() assert test_operation.evaluate_output(0, [2 + 1j, 3 - 2j]) == 5 - 1j assert test_operation.evaluate_output(1, [2 + 1j, 3 - 2j]) == -1 + 3j @@ -250,6 +353,7 @@ class TestSymmetricTwoportAdaptor: test_operation = SymmetricTwoportAdaptor(0.5) assert test_operation.evaluate_output(0, [2, 3]) == 3.5 assert test_operation.evaluate_output(1, [2, 3]) == 2.5 + assert test_operation.value == 0.5 def test_symmetrictwoportadaptor_negative(self): test_operation = SymmetricTwoportAdaptor(0.5) @@ -267,6 +371,13 @@ class TestSymmetricTwoportAdaptor: test_operation.swap_io() assert test_operation.value == -0.5 + def test_symmetrictwoportadaptor_error(self): + with pytest.raises(ValueError, match="value must be between -1 and 1"): + _ = SymmetricTwoportAdaptor(-2) + test_operation = SymmetricTwoportAdaptor(0) + with pytest.raises(ValueError, match="value must be between -1 and 1"): + test_operation.value = 2 + class TestReciprocal: """Tests for Absolute class."""