Skip to content
Snippets Groups Projects
Commit eb1349d2 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Add *Shift-operations

parent 2ce86e81
No related branches found
No related tags found
1 merge request!404Add *Shift-operations
Pipeline #97747 passed
......@@ -783,7 +783,7 @@ class SymmetricTwoportAdaptor(AbstractOperation):
latency_offsets=latency_offsets,
execution_time=execution_time,
)
self.set_param("value", value)
self.value = value
@classmethod
def type_name(cls) -> TypeName:
......@@ -801,7 +801,10 @@ class SymmetricTwoportAdaptor(AbstractOperation):
@value.setter
def value(self, value: Num) -> None:
"""Set the constant value of this operation."""
self.set_param("value", value)
if -1 <= value <= 1:
self.set_param("value", value)
else:
raise ValueError('value must be between -1 and 1 (inclusive)')
def swap_io(self) -> None:
# Swap inputs and outputs and change sign of coefficient
......@@ -852,3 +855,233 @@ class Reciprocal(AbstractOperation):
def evaluate(self, a):
return 1 / a
class RightShift(AbstractOperation):
r"""
Arithmetic right-shift operation.
Shifts the input to the right assuming a fixed-point representation, so
a multiplication by a power of two.
.. math:: y = x \gg \text{value} = 2^{-\text{value}}x \text{ where value} \geq 0
Parameters
----------
value : int
Number of bits to shift right.
src0 : :class:`~b_asic.port.SignalSourceProvider`, optional
The signal to shift right.
name : Name, optional
Operation name.
latency : int, optional
Operation latency (delay from input to output in time units).
latency_offsets : dict[str, int], optional
Used if input arrives later than when the operator starts, e.g.,
``{"in0": 0`` which corresponds to *src0* arriving one time unit after the
operator starts. If not provided and *latency* is provided, set to zero.
execution_time : int, optional
Operation execution time (time units before operator can be reused).
See Also
--------
LeftShift
Shift
"""
is_linear = True
def __init__(
self,
value: int = 0,
src0: Optional[SignalSourceProvider] = None,
name: Name = Name(""),
latency: Optional[int] = None,
latency_offsets: Optional[Dict[str, int]] = None,
execution_time: Optional[int] = None,
):
"""Construct a RightShift operation with the given value."""
super().__init__(
input_count=1,
output_count=1,
name=Name(name),
input_sources=[src0],
latency=latency,
latency_offsets=latency_offsets,
execution_time=execution_time,
)
self.value = value
@classmethod
def type_name(cls) -> TypeName:
return TypeName("rshift")
def evaluate(self, a):
return a * 2 ** (-self.param("value"))
@property
def value(self) -> int:
"""Get the constant value of this operation."""
return self.param("value")
@value.setter
def value(self, value: int) -> None:
"""Set the constant value of this operation."""
if not isinstance(value, int):
raise TypeError("value must be an int")
if value < 0:
raise ValueError("value must be non-negative")
self.set_param("value", value)
class LeftShift(AbstractOperation):
r"""
Arithmetic left-shift operation.
Shifts the input to the left assuming a fixed-point representation, so
a multiplication by a power of two.
.. math:: y = x \ll \text{value} = 2^{\text{value}}x \text{ where value} \geq 0
Parameters
----------
value : int
Number of bits to shift left.
src0 : :class:`~b_asic.port.SignalSourceProvider`, optional
The signal to shift left.
name : Name, optional
Operation name.
latency : int, optional
Operation latency (delay from input to output in time units).
latency_offsets : dict[str, int], optional
Used if input arrives later than when the operator starts, e.g.,
``{"in0": 0`` which corresponds to *src0* arriving one time unit after the
operator starts. If not provided and *latency* is provided, set to zero.
execution_time : int, optional
Operation execution time (time units before operator can be reused).
See Also
--------
RightShift
Shift
"""
is_linear = True
def __init__(
self,
value: int = 0,
src0: Optional[SignalSourceProvider] = None,
name: Name = Name(""),
latency: Optional[int] = None,
latency_offsets: Optional[Dict[str, int]] = None,
execution_time: Optional[int] = None,
):
"""Construct a RightShift operation with the given value."""
super().__init__(
input_count=1,
output_count=1,
name=Name(name),
input_sources=[src0],
latency=latency,
latency_offsets=latency_offsets,
execution_time=execution_time,
)
self.value = value
@classmethod
def type_name(cls) -> TypeName:
return TypeName("lshift")
def evaluate(self, a):
return a * 2 ** (self.param("value"))
@property
def value(self) -> int:
"""Get the constant value of this operation."""
return self.param("value")
@value.setter
def value(self, value: int) -> None:
"""Set the constant value of this operation."""
if not isinstance(value, int):
raise TypeError("value must be an int")
if value < 0:
raise ValueError("value must be non-negative")
self.set_param("value", value)
class Shift(AbstractOperation):
r"""
Arithmetic shift operation.
Shifts the input to the left or right assuming a fixed-point representation, so
a multiplication by a power of two. By definition a positive value is a shift to
the left.
.. math:: y = x \ll \text{value} = 2^{\text{value}}x
Parameters
----------
value : int
Number of bits to shift. Positive *value* shifts to the left.
src0 : :class:`~b_asic.port.SignalSourceProvider`, optional
The signal to shift.
name : Name, optional
Operation name.
latency : int, optional
Operation latency (delay from input to output in time units).
latency_offsets : dict[str, int], optional
Used if input arrives later than when the operator starts, e.g.,
``{"in0": 0`` which corresponds to *src0* arriving one time unit after the
operator starts. If not provided and *latency* is provided, set to zero.
execution_time : int, optional
Operation execution time (time units before operator can be reused).
See Also
--------
LeftShift
RightShift
"""
is_linear = True
def __init__(
self,
value: int = 0,
src0: Optional[SignalSourceProvider] = None,
name: Name = Name(""),
latency: Optional[int] = None,
latency_offsets: Optional[Dict[str, int]] = None,
execution_time: Optional[int] = None,
):
"""Construct a Shift operation with the given value."""
super().__init__(
input_count=1,
output_count=1,
name=Name(name),
input_sources=[src0],
latency=latency,
latency_offsets=latency_offsets,
execution_time=execution_time,
)
self.value = value
@classmethod
def type_name(cls) -> TypeName:
return TypeName("shift")
def evaluate(self, a):
return a * 2 ** (self.param("value"))
@property
def value(self) -> int:
"""Get the constant value of this operation."""
return self.param("value")
@value.setter
def value(self, value: int) -> None:
"""Set the constant value of this operation."""
if not isinstance(value, int):
raise TypeError("value must be an int")
self.set_param("value", value)
"""B-ASIC test suite for the core operations."""
import pytest
from b_asic import (
Absolute,
......@@ -9,10 +10,13 @@ from b_asic import (
Constant,
ConstantMultiplication,
Division,
LeftShift,
Max,
Min,
Multiplication,
Reciprocal,
RightShift,
Shift,
SquareRoot,
Subtraction,
SymmetricTwoportAdaptor,
......@@ -25,6 +29,7 @@ class TestConstant:
def test_constant_positive(self):
test_operation = Constant(3)
assert test_operation.evaluate_output(0, []) == 3
assert test_operation.value == 3
def test_constant_negative(self):
test_operation = Constant(-3)
......@@ -83,6 +88,7 @@ class TestAddSub:
def test_addition_negative(self):
test_operation = AddSub(is_add=True)
assert test_operation.evaluate_output(0, [-3, -5]) == -8
assert test_operation.is_add
def test_addition_complex(self):
test_operation = AddSub(is_add=True)
......@@ -91,6 +97,7 @@ class TestAddSub:
def test_addsub_subtraction_positive(self):
test_operation = AddSub(is_add=False)
assert test_operation.evaluate_output(0, [5, 3]) == 2
assert not test_operation.is_add
def test_addsub_subtraction_negative(self):
test_operation = AddSub(is_add=False)
......@@ -214,6 +221,7 @@ class TestConstantMultiplication:
def test_constantmultiplication_positive(self):
test_operation = ConstantMultiplication(5)
assert test_operation.evaluate_output(0, [20]) == 100
assert test_operation.value == 5
def test_constantmultiplication_negative(self):
test_operation = ConstantMultiplication(5)
......@@ -224,6 +232,101 @@ class TestConstantMultiplication:
assert test_operation.evaluate_output(0, [3 + 4j]) == 1 + 18j
class TestRightShift:
"""Tests for RightShift class."""
def test_rightshift_positive(self):
test_operation = RightShift(2)
assert test_operation.evaluate_output(0, [20]) == 5
assert test_operation.value == 2
def test_rightshift_negative(self):
test_operation = RightShift(2)
assert test_operation.evaluate_output(0, [-5]) == -1.25
def test_rightshift_complex(self):
test_operation = RightShift(2)
assert test_operation.evaluate_output(0, [2 + 1j]) == 0.5 + 0.25j
def test_rightshift_errors(self):
with pytest.raises(TypeError, match="value must be an int"):
_ = RightShift(0.5)
test_operation = RightShift(0)
with pytest.raises(TypeError, match="value must be an int"):
test_operation.value = 0.5
with pytest.raises(ValueError, match="value must be non-negative"):
_ = RightShift(-1)
test_operation = RightShift(0)
with pytest.raises(ValueError, match="value must be non-negative"):
test_operation.value = -1
class TestLeftShift:
"""Tests for LeftShift class."""
def test_leftshift_positive(self):
test_operation = LeftShift(2)
assert test_operation.evaluate_output(0, [5]) == 20
assert test_operation.value == 2
def test_leftshift_negative(self):
test_operation = LeftShift(2)
assert test_operation.evaluate_output(0, [-5]) == -20
def test_leftshift_complex(self):
test_operation = LeftShift(2)
assert test_operation.evaluate_output(0, [0.5 + 0.25j]) == 2 + 1j
def test_leftshift_errors(self):
with pytest.raises(TypeError, match="value must be an int"):
_ = LeftShift(0.5)
test_operation = LeftShift(0)
with pytest.raises(TypeError, match="value must be an int"):
test_operation.value = 0.5
with pytest.raises(ValueError, match="value must be non-negative"):
_ = LeftShift(-1)
test_operation = LeftShift(0)
with pytest.raises(ValueError, match="value must be non-negative"):
test_operation.value = -1
class TestShift:
"""Tests for Shift class."""
def test_shift_positive(self):
test_operation = Shift(2)
assert test_operation.evaluate_output(0, [5]) == 20
assert test_operation.value == 2
test_operation = Shift(-2)
assert test_operation.evaluate_output(0, [5]) == 1.25
assert test_operation.value == -2
def test_shift_negative(self):
test_operation = Shift(2)
assert test_operation.evaluate_output(0, [-5]) == -20
test_operation = Shift(-2)
assert test_operation.evaluate_output(0, [-5]) == -1.25
def test_shift_complex(self):
test_operation = Shift(2)
assert test_operation.evaluate_output(0, [0.5 + 0.25j]) == 2 + 1j
test_operation = Shift(-2)
assert test_operation.evaluate_output(0, [2 + 1j]) == 0.5 + 0.25j
@pytest.mark.parametrize("val", (-0.5, 0.5))
def test_leftshift_errors(self, val):
with pytest.raises(TypeError, match="value must be an int"):
_ = Shift(val)
test_operation = Shift(0)
with pytest.raises(TypeError, match="value must be an int"):
test_operation.value = val
class TestButterfly:
"""Tests for Butterfly class."""
......@@ -237,7 +340,7 @@ class TestButterfly:
assert test_operation.evaluate_output(0, [-2, -3]) == -5
assert test_operation.evaluate_output(1, [-2, -3]) == 1
def test_buttefly_complex(self):
def test_butterfly_complex(self):
test_operation = Butterfly()
assert test_operation.evaluate_output(0, [2 + 1j, 3 - 2j]) == 5 - 1j
assert test_operation.evaluate_output(1, [2 + 1j, 3 - 2j]) == -1 + 3j
......@@ -250,6 +353,7 @@ class TestSymmetricTwoportAdaptor:
test_operation = SymmetricTwoportAdaptor(0.5)
assert test_operation.evaluate_output(0, [2, 3]) == 3.5
assert test_operation.evaluate_output(1, [2, 3]) == 2.5
assert test_operation.value == 0.5
def test_symmetrictwoportadaptor_negative(self):
test_operation = SymmetricTwoportAdaptor(0.5)
......@@ -267,6 +371,13 @@ class TestSymmetricTwoportAdaptor:
test_operation.swap_io()
assert test_operation.value == -0.5
def test_symmetrictwoportadaptor_error(self):
with pytest.raises(ValueError, match="value must be between -1 and 1"):
_ = SymmetricTwoportAdaptor(-2)
test_operation = SymmetricTwoportAdaptor(0)
with pytest.raises(ValueError, match="value must be between -1 and 1"):
test_operation.value = 2
class TestReciprocal:
"""Tests for Absolute class."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment