Skip to content
Snippets Groups Projects
Commit d8d42645 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Move operator overloading to SignalSourceProvider

parent 6be246d5
No related branches found
No related tags found
1 merge request!410Move operator overloading to SignalSourceProvider
Pipeline #97847 passed
......@@ -31,15 +31,6 @@ from b_asic.signal import Signal
from b_asic.types import Num
if TYPE_CHECKING:
# Conditionally imported to avoid circular imports
from b_asic.core_operations import (
Addition,
ConstantMultiplication,
Division,
Multiplication,
Reciprocal,
Subtraction,
)
from b_asic.signal_flow_graph import SFG
......@@ -62,82 +53,6 @@ class Operation(GraphComponent, SignalSourceProvider):
Operations may specify how to quantize inputs through quantize_input().
"""
@abstractmethod
def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
"""
Overload the addition operator to make it return a new Addition operation
object that is connected to the self and other objects.
"""
raise NotImplementedError
@abstractmethod
def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
"""
Overload the addition operator to make it return a new Addition operation
object that is connected to the self and other objects.
"""
raise NotImplementedError
@abstractmethod
def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
"""
Overload the subtraction operator to make it return a new Subtraction
operation object that is connected to the self and other objects.
"""
raise NotImplementedError
@abstractmethod
def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
"""
Overloads the subtraction operator to make it return a new Subtraction
operation object that is connected to the self and other objects.
"""
raise NotImplementedError
@abstractmethod
def __mul__(
self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
"""
Overload the multiplication operator to make it return a new Multiplication
operation object that is connected to the self and other objects.
If *src* is a number, then returns a ConstantMultiplication operation object
instead.
"""
raise NotImplementedError
@abstractmethod
def __rmul__(
self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
"""
Overload the multiplication operator to make it return a new Multiplication
operation object that is connected to the self and other objects.
If *src* is a number, then returns a ConstantMultiplication operation object
instead.
"""
raise NotImplementedError
@abstractmethod
def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division":
"""
Overload the division operator to make it return a new Division operation
object that is connected to the self and other objects.
"""
raise NotImplementedError
@abstractmethod
def __rtruediv__(
self, src: Union[SignalSourceProvider, Num]
) -> Union["Division", "Reciprocal"]:
"""
Overload the division operator to make it return a new Division operation
object that is connected to the self and other objects.
"""
raise NotImplementedError
@abstractmethod
def __lshift__(self, src: SignalSourceProvider) -> Signal:
"""
......@@ -626,76 +541,6 @@ class AbstractOperation(Operation, AbstractGraphComponent):
"""
raise NotImplementedError
def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
# Import here to avoid circular imports.
from b_asic.core_operations import Addition, Constant
if isinstance(src, Number):
return Addition(self, Constant(src))
else:
return Addition(self, src)
def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
# Import here to avoid circular imports.
from b_asic.core_operations import Addition, Constant
return Addition(Constant(src) if isinstance(src, Number) else src, self)
def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Subtraction
return Subtraction(self, Constant(src) if isinstance(src, Number) else src)
def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Subtraction
return Subtraction(Constant(src) if isinstance(src, Number) else src, self)
def __mul__(
self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
# Import here to avoid circular imports.
from b_asic.core_operations import ConstantMultiplication, Multiplication
return (
ConstantMultiplication(src, self)
if isinstance(src, Number)
else Multiplication(self, src)
)
def __rmul__(
self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
# Import here to avoid circular imports.
from b_asic.core_operations import ConstantMultiplication, Multiplication
return (
ConstantMultiplication(src, self)
if isinstance(src, Number)
else Multiplication(src, self)
)
def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division":
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Division
return Division(self, Constant(src) if isinstance(src, Number) else src)
def __rtruediv__(
self, src: Union[SignalSourceProvider, Num]
) -> Union["Division", "Reciprocal"]:
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Division, Reciprocal
if isinstance(src, Number):
if src == 1:
return Reciprocal(self)
else:
return Division(Constant(src), self)
return Division(src, self)
def __lshift__(self, src: SignalSourceProvider) -> Signal:
if self.input_count != 1:
diff = "more" if self.input_count > 1 else "less"
......
......@@ -6,12 +6,22 @@ Contains classes for managing the ports of operations.
from abc import ABC, abstractmethod
from copy import copy
from typing import TYPE_CHECKING, List, Optional, Sequence
from numbers import Number
from typing import TYPE_CHECKING, List, Optional, Sequence, Union
from b_asic.graph_component import Name
from b_asic.signal import Signal
from b_asic.types import Num
if TYPE_CHECKING:
from b_asic.core_operations import (
Addition,
ConstantMultiplication,
Division,
Multiplication,
Reciprocal,
Subtraction,
)
from b_asic.operation import Operation
......@@ -167,6 +177,76 @@ class SignalSourceProvider(ABC):
"""Get the main source port provided by this object."""
raise NotImplementedError
def __add__(self, src: Union["SignalSourceProvider", Num]) -> "Addition":
# Import here to avoid circular imports.
from b_asic.core_operations import Addition, Constant
if isinstance(src, Number):
return Addition(self, Constant(src))
else:
return Addition(self, src)
def __radd__(self, src: Union["SignalSourceProvider", Num]) -> "Addition":
# Import here to avoid circular imports.
from b_asic.core_operations import Addition, Constant
return Addition(Constant(src) if isinstance(src, Number) else src, self)
def __sub__(self, src: Union["SignalSourceProvider", Num]) -> "Subtraction":
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Subtraction
return Subtraction(self, Constant(src) if isinstance(src, Number) else src)
def __rsub__(self, src: Union["SignalSourceProvider", Num]) -> "Subtraction":
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Subtraction
return Subtraction(Constant(src) if isinstance(src, Number) else src, self)
def __mul__(
self, src: Union["SignalSourceProvider", Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
# Import here to avoid circular imports.
from b_asic.core_operations import ConstantMultiplication, Multiplication
return (
ConstantMultiplication(src, self)
if isinstance(src, Number)
else Multiplication(self, src)
)
def __rmul__(
self, src: Union["SignalSourceProvider", Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
# Import here to avoid circular imports.
from b_asic.core_operations import ConstantMultiplication, Multiplication
return (
ConstantMultiplication(src, self)
if isinstance(src, Number)
else Multiplication(src, self)
)
def __truediv__(self, src: Union["SignalSourceProvider", Num]) -> "Division":
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Division
return Division(self, Constant(src) if isinstance(src, Number) else src)
def __rtruediv__(
self, src: Union["SignalSourceProvider", Num]
) -> Union["Division", "Reciprocal"]:
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Division, Reciprocal
if isinstance(src, Number):
if src == 1:
return Reciprocal(self)
else:
return Division(Constant(src), self)
return Division(src, self)
class InputPort(AbstractPort):
"""
......
......@@ -40,6 +40,12 @@ class TestOperationOverloading:
assert add5.input(0).signals[0].source.operation.value == 5
assert add5.input(1).signals == add4.output(0).signals
bfly = Butterfly()
add6 = bfly.output(0) + add5
assert isinstance(add6, Addition)
assert add6.input(0).signals == bfly.output(0).signals
assert add6.input(1).signals == add5.output(0).signals
def test_subtraction_overload(self):
"""Tests subtraction overloading for both operation and number argument."""
add1 = Addition(None, None, "add1")
......@@ -60,6 +66,12 @@ class TestOperationOverloading:
assert sub3.input(0).signals[0].source.operation.value == 5
assert sub3.input(1).signals == sub2.output(0).signals
bfly = Butterfly()
sub4 = bfly.output(0) - sub3
assert isinstance(sub4, Subtraction)
assert sub4.input(0).signals == bfly.output(0).signals
assert sub4.input(1).signals == sub3.output(0).signals
def test_multiplication_overload(self):
"""Tests multiplication overloading for both operation and number argument."""
add1 = Addition(None, None, "add1")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment