Skip to content
Snippets Groups Projects
Commit 6e9b10ad authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Add Reciprocal operation

parent a43ef562
No related branches found
No related tags found
1 merge request!163Add Reciprocal operation
Pipeline #88832 passed
......@@ -614,8 +614,8 @@ class Butterfly(AbstractOperation):
.. math::
\begin{eqnarray}
y_0 = x_0 + x_1\\
y_1 = x_0 - x_1
y_0 & = & x_0 + x_1\\
y_1 & = & x_0 - x_1
\end{eqnarray}
"""
......@@ -692,8 +692,8 @@ class SymmetricTwoportAdaptor(AbstractOperation):
.. math::
\begin{eqnarray}
y_0 = x_1 + \text{value}\times\left(x_1 - x_0\right)\\
y_1 = x_0 + \text{value}\times\left(x_1 - x_0\right)
y_0 & = & x_1 + \text{value}\times\left(x_1 - x_0\right)\\
y_1 & = & x_0 + \text{value}\times\left(x_1 - x_0\right)
\end{eqnarray}
"""
......@@ -736,3 +736,39 @@ class SymmetricTwoportAdaptor(AbstractOperation):
def value(self, value: Number) -> None:
"""Set the constant value of this operation."""
self.set_param("value", value)
class Reciprocal(AbstractOperation):
r"""
Reciprocal operation.
Gives the reciprocal of its input.
.. math:: y = \frac{1}{x}
"""
def __init__(
self,
src0: Optional[SignalSourceProvider] = None,
name: Name = Name(""),
latency: Optional[int] = None,
latency_offsets: Optional[Dict[str, int]] = None,
execution_time: Optional[int] = None,
):
"""Construct an Reciprocal operation."""
super().__init__(
input_count=1,
output_count=1,
name=Name(name),
input_sources=[src0],
latency=latency,
latency_offsets=latency_offsets,
execution_time=execution_time,
)
@classmethod
def type_name(cls) -> TypeName:
return TypeName("rec")
def evaluate(self, a):
return 1 / a
......@@ -40,6 +40,7 @@ if TYPE_CHECKING:
ConstantMultiplication,
Division,
Multiplication,
Reciprocal,
Subtraction,
)
from b_asic.signal_flow_graph import SFG
......@@ -135,7 +136,7 @@ class Operation(GraphComponent, SignalSourceProvider):
@abstractmethod
def __rtruediv__(
self, src: Union[SignalSourceProvider, Number]
) -> "Division":
) -> Union["Division", "Reciprocal"]:
"""
Overloads the division operator to make it return a new Division operation
object that is connected to the self and other objects.
......@@ -387,7 +388,7 @@ class Operation(GraphComponent, SignalSourceProvider):
self,
) -> Tuple[List[List[float]], List[List[float]]]:
"""
Get a tuple constaining coordinates for the two polygons outlining
Return a tuple containing coordinates for the two polygons outlining
the latency and execution time of the operation.
The polygons are corresponding to a start time of 0 and are of height 1.
"""
......@@ -398,7 +399,7 @@ class Operation(GraphComponent, SignalSourceProvider):
self,
) -> Tuple[List[List[float]], List[List[float]]]:
"""
Get a tuple constaining coordinates for inputs and outputs, respectively.
Return a tuple containing coordinates for inputs and outputs, respectively.
These maps to the polygons and are corresponding to a start time of 0
and height 1.
"""
......@@ -500,9 +501,9 @@ class AbstractOperation(Operation, AbstractGraphComponent):
for inp in self.inputs:
if inp.latency_offset is None:
inp.latency_offset = 0
for outp in self.outputs:
if outp.latency_offset is None:
outp.latency_offset = latency
for output in self.outputs:
if output.latency_offset is None:
output.latency_offset = latency
self._execution_time = execution_time
......@@ -592,13 +593,16 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def __rtruediv__(
self, src: Union[SignalSourceProvider, Number]
) -> "Division":
) -> Union["Division", "Reciprocal"]:
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Division
from b_asic.core_operations import Constant, Division, Reciprocal
return Division(
Constant(src) if isinstance(src, Number) else src, self
)
if isinstance(src, Number):
if src == 1:
return Reciprocal(self)
else:
return Division(Constant(src), self)
return Division(src, self)
def __lshift__(self, src: SignalSourceProvider) -> Signal:
if self.input_count != 1:
......@@ -835,10 +839,10 @@ class AbstractOperation(Operation, AbstractGraphComponent):
new_component: Operation = cast(
Operation, super().copy_component(*args, **kwargs)
)
for i, inp in enumerate(self.inputs):
new_component.input(i).latency_offset = inp.latency_offset
for i, outp in enumerate(self.outputs):
new_component.output(i).latency_offset = outp.latency_offset
for i, input in enumerate(self.inputs):
new_component.input(i).latency_offset = input.latency_offset
for i, output in enumerate(self.outputs):
new_component.output(i).latency_offset = output.latency_offset
new_component.execution_time = self._execution_time
return new_component
......@@ -930,7 +934,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
@property
def latency(self) -> int:
if None in [inp.latency_offset for inp in self.inputs] or None in [
outp.latency_offset for outp in self.outputs
output.latency_offset for output in self.outputs
]:
raise ValueError(
"All native offsets have to set to a non-negative value to"
......@@ -940,10 +944,10 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return max(
(
(
cast(int, outp.latency_offset)
- cast(int, inp.latency_offset)
cast(int, output.latency_offset)
- cast(int, input.latency_offset)
)
for outp, inp in it.product(self.outputs, self.inputs)
for output, input in it.product(self.outputs, self.inputs)
)
)
......@@ -951,11 +955,11 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def latency_offsets(self) -> Dict[str, Optional[int]]:
latency_offsets = {}
for i, inp in enumerate(self.inputs):
latency_offsets[f"in{i}"] = inp.latency_offset
for i, input in enumerate(self.inputs):
latency_offsets[f"in{i}"] = input.latency_offset
for i, outp in enumerate(self.outputs):
latency_offsets[f"out{i}"] = outp.latency_offset
for i, output in enumerate(self.outputs):
latency_offsets[f"out{i}"] = output.latency_offset
return latency_offsets
......@@ -1072,18 +1076,18 @@ class AbstractOperation(Operation, AbstractGraphComponent):
) -> Tuple[List[List[float]], List[List[float]]]:
# Doc-string inherited
self._check_all_latencies_set()
input_coords = [
input_coordinates = [
[
self.inputs[k].latency_offset,
(1 + 2 * k) / (2 * len(self.inputs)),
]
for k in range(len(self.inputs))
]
output_coords = [
output_coordinates = [
[
self.outputs[k].latency_offset,
(1 + 2 * k) / (2 * len(self.outputs)),
]
for k in range(len(self.outputs))
]
return input_coords, output_coords
return input_coordinates, output_coordinates
......@@ -12,6 +12,7 @@ from b_asic import (
Max,
Min,
Multiplication,
Reciprocal,
SquareRoot,
Subtraction,
SymmetricTwoportAdaptor,
......@@ -261,6 +262,22 @@ class TestSymmetricTwoportAdaptor:
)
class TestReciprocal:
"""Tests for Absolute class."""
def test_reciprocal_positive(self):
test_operation = Reciprocal()
assert test_operation.evaluate_output(0, [2]) == 0.5
def test_reciprocal_negative(self):
test_operation = Reciprocal()
assert test_operation.evaluate_output(0, [-5]) == -0.2
def test_reciprocal_complex(self):
test_operation = Reciprocal()
assert test_operation.evaluate_output(0, [1 + 1j]) == 0.5 - 0.5j
class TestDepends:
def test_depends_addition(self):
add1 = Addition()
......
......@@ -10,6 +10,7 @@ from b_asic import (
ConstantMultiplication,
Division,
Multiplication,
Reciprocal,
SquareRoot,
Subtraction,
)
......@@ -100,6 +101,10 @@ class TestOperationOverloading:
assert div3.input(0).signals[0].source.operation.value == 5
assert div3.input(1).signals == div2.output(0).signals
div4 = 1 / div3
assert isinstance(div4, Reciprocal)
assert div4.input(0).signals == div3.output(0).signals
class TestTraverse:
def test_traverse_single_tree(self, operation):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment