Skip to content
Snippets Groups Projects
Commit 8ea03312 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Add is_linear and is_constant properties

parent b2e3d6f8
No related branches found
No related tags found
1 merge request!218Add is_linear and is_constant properties
Pipeline #90094 passed
......@@ -66,6 +66,20 @@ class Constant(AbstractOperation):
def latency(self) -> int:
return self.latency_offsets["out0"]
def __repr__(self) -> str:
return f"Constant({self.value})"
def __str__(self) -> str:
return f"{self.value}"
@property
def is_linear(self) -> bool:
return True
@property
def is_constant(self) -> bool:
return True
class Addition(AbstractOperation):
"""
......@@ -129,6 +143,10 @@ class Addition(AbstractOperation):
def evaluate(self, a, b):
return a + b
@property
def is_linear(self) -> bool:
return True
class Subtraction(AbstractOperation):
"""
......@@ -189,6 +207,10 @@ class Subtraction(AbstractOperation):
def evaluate(self, a, b):
return a - b
@property
def is_linear(self) -> bool:
return True
class AddSub(AbstractOperation):
r"""
......@@ -270,6 +292,10 @@ class AddSub(AbstractOperation):
"""Set if operation is an addition."""
self.set_param("is_add", is_add)
@property
def is_linear(self) -> bool:
return True
class Multiplication(AbstractOperation):
r"""
......@@ -331,6 +357,12 @@ class Multiplication(AbstractOperation):
def evaluate(self, a, b):
return a * b
@property
def is_linear(self) -> bool:
return any(
input.connected_source.operation.is_constant for input in self.inputs
)
class Division(AbstractOperation):
r"""
......@@ -372,6 +404,10 @@ class Division(AbstractOperation):
def evaluate(self, a, b):
return a / b
@property
def is_linear(self) -> bool:
return self.input(1).connected_source.operation.is_constant
class Min(AbstractOperation):
r"""
......@@ -618,6 +654,10 @@ class ConstantMultiplication(AbstractOperation):
"""Set the constant value of this operation."""
self.set_param("value", value)
@property
def is_linear(self) -> bool:
return True
class Butterfly(AbstractOperation):
r"""
......@@ -660,6 +700,10 @@ class Butterfly(AbstractOperation):
def evaluate(self, a, b):
return a + b, a - b
@property
def is_linear(self) -> bool:
return True
class MAD(AbstractOperation):
r"""
......@@ -699,6 +743,13 @@ class MAD(AbstractOperation):
def evaluate(self, a, b, c):
return a * b + c
@property
def is_linear(self) -> bool:
return (
self.input(0).connected_source.operation.is_constant
or self.input(1).connected_source.operation.is_constant
)
class SymmetricTwoportAdaptor(AbstractOperation):
r"""
......@@ -751,6 +802,10 @@ class SymmetricTwoportAdaptor(AbstractOperation):
"""Set the constant value of this operation."""
self.set_param("value", value)
@property
def is_linear(self) -> bool:
return True
class Reciprocal(AbstractOperation):
r"""
......
......@@ -467,6 +467,22 @@ class Operation(GraphComponent, SignalSourceProvider):
def _check_all_latencies_set(self) -> None:
raise NotImplementedError
@property
@abstractmethod
def is_linear(self) -> bool:
"""
Return True if the operation is linear.
"""
raise NotImplementedError
@property
@abstractmethod
def is_constant(self) -> bool:
"""
Return True if the output of the operation is constant.
"""
raise NotImplementedError
class AbstractOperation(Operation, AbstractGraphComponent):
"""
......@@ -1135,3 +1151,15 @@ class AbstractOperation(Operation, AbstractGraphComponent):
)
for k in range(num_out)
)
@property
def is_linear(self) -> bool:
if self.is_constant:
return True
return False
@property
def is_constant(self) -> bool:
return all(
input.connected_source.operation.is_constant for input in self.inputs
)
......@@ -1540,3 +1540,11 @@ class SFG(AbstractOperation):
assert len(ids) == len(set(ids))
return SFG(inputs=all_inputs, outputs=all_outputs)
@property
def is_linear(self) -> bool:
return all(op.is_linear for op in self.split())
@property
def is_constant(self) -> bool:
return all(output.is_constant for output in self._output_operations)
......@@ -89,6 +89,14 @@ class Input(AbstractOperation):
# doc-string inherited
return ((0, 0.5),)
@property
def is_constant(self) -> bool:
return False
@property
def is_linear(self) -> bool:
return True
class Output(AbstractOperation):
"""
......@@ -143,6 +151,10 @@ class Output(AbstractOperation):
def latency(self) -> int:
return self.latency_offsets["in0"]
@property
def is_linear(self) -> bool:
return True
class Delay(AbstractOperation):
"""
......@@ -221,3 +233,7 @@ class Delay(AbstractOperation):
def initial_value(self, value: Num) -> None:
"""Set the initial value of this delay."""
self.set_param("initial_value", value)
@property
def is_linear(self) -> bool:
return True
......@@ -1604,3 +1604,19 @@ class TestUnfold:
sfg = sfg_two_inputs_two_outputs
with pytest.raises(ValueError, match="Unfolding 0 times removes the SFG"):
sfg.unfold(0)
class TestIsLinear:
def test_single_accumulator(self, sfg_simple_accumulator: SFG):
assert sfg_simple_accumulator.is_linear
def test_sfg_nested(self, sfg_nested: SFG):
assert not sfg_nested.is_linear
class TestIsConstant:
def test_single_accumulator(self, sfg_simple_accumulator: SFG):
assert not sfg_simple_accumulator.is_constant
def test_sfg_nested(self, sfg_nested: SFG):
assert not sfg_nested.is_constant
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment