From d8d426459f7d26d14db117c38ab6220ea776850e Mon Sep 17 00:00:00 2001 From: Oscar Gustafsson <oscar.gustafsson@gmail.com> Date: Fri, 19 May 2023 21:42:22 +0200 Subject: [PATCH] Move operator overloading to SignalSourceProvider --- b_asic/operation.py | 155 ----------------------------------------- b_asic/port.py | 82 +++++++++++++++++++++- test/test_operation.py | 12 ++++ 3 files changed, 93 insertions(+), 156 deletions(-) diff --git a/b_asic/operation.py b/b_asic/operation.py index 2d0c8972..4e6edd5f 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -31,15 +31,6 @@ from b_asic.signal import Signal from b_asic.types import Num if TYPE_CHECKING: - # Conditionally imported to avoid circular imports - from b_asic.core_operations import ( - Addition, - ConstantMultiplication, - Division, - Multiplication, - Reciprocal, - Subtraction, - ) from b_asic.signal_flow_graph import SFG @@ -62,82 +53,6 @@ class Operation(GraphComponent, SignalSourceProvider): Operations may specify how to quantize inputs through quantize_input(). """ - @abstractmethod - def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition": - """ - Overload the addition operator to make it return a new Addition operation - object that is connected to the self and other objects. - """ - raise NotImplementedError - - @abstractmethod - def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition": - """ - Overload the addition operator to make it return a new Addition operation - object that is connected to the self and other objects. - """ - raise NotImplementedError - - @abstractmethod - def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": - """ - Overload the subtraction operator to make it return a new Subtraction - operation object that is connected to the self and other objects. - """ - raise NotImplementedError - - @abstractmethod - def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": - """ - Overloads the subtraction operator to make it return a new Subtraction - operation object that is connected to the self and other objects. - """ - raise NotImplementedError - - @abstractmethod - def __mul__( - self, src: Union[SignalSourceProvider, Num] - ) -> Union["Multiplication", "ConstantMultiplication"]: - """ - Overload the multiplication operator to make it return a new Multiplication - operation object that is connected to the self and other objects. - - If *src* is a number, then returns a ConstantMultiplication operation object - instead. - """ - raise NotImplementedError - - @abstractmethod - def __rmul__( - self, src: Union[SignalSourceProvider, Num] - ) -> Union["Multiplication", "ConstantMultiplication"]: - """ - Overload the multiplication operator to make it return a new Multiplication - operation object that is connected to the self and other objects. - - If *src* is a number, then returns a ConstantMultiplication operation object - instead. - """ - raise NotImplementedError - - @abstractmethod - def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division": - """ - Overload the division operator to make it return a new Division operation - object that is connected to the self and other objects. - """ - raise NotImplementedError - - @abstractmethod - def __rtruediv__( - self, src: Union[SignalSourceProvider, Num] - ) -> Union["Division", "Reciprocal"]: - """ - Overload the division operator to make it return a new Division operation - object that is connected to the self and other objects. - """ - raise NotImplementedError - @abstractmethod def __lshift__(self, src: SignalSourceProvider) -> Signal: """ @@ -626,76 +541,6 @@ class AbstractOperation(Operation, AbstractGraphComponent): """ raise NotImplementedError - def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition": - # Import here to avoid circular imports. - from b_asic.core_operations import Addition, Constant - - if isinstance(src, Number): - return Addition(self, Constant(src)) - else: - return Addition(self, src) - - def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition": - # Import here to avoid circular imports. - from b_asic.core_operations import Addition, Constant - - return Addition(Constant(src) if isinstance(src, Number) else src, self) - - def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": - # Import here to avoid circular imports. - from b_asic.core_operations import Constant, Subtraction - - return Subtraction(self, Constant(src) if isinstance(src, Number) else src) - - def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": - # Import here to avoid circular imports. - from b_asic.core_operations import Constant, Subtraction - - return Subtraction(Constant(src) if isinstance(src, Number) else src, self) - - def __mul__( - self, src: Union[SignalSourceProvider, Num] - ) -> Union["Multiplication", "ConstantMultiplication"]: - # Import here to avoid circular imports. - from b_asic.core_operations import ConstantMultiplication, Multiplication - - return ( - ConstantMultiplication(src, self) - if isinstance(src, Number) - else Multiplication(self, src) - ) - - def __rmul__( - self, src: Union[SignalSourceProvider, Num] - ) -> Union["Multiplication", "ConstantMultiplication"]: - # Import here to avoid circular imports. - from b_asic.core_operations import ConstantMultiplication, Multiplication - - return ( - ConstantMultiplication(src, self) - if isinstance(src, Number) - else Multiplication(src, self) - ) - - def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division": - # Import here to avoid circular imports. - from b_asic.core_operations import Constant, Division - - return Division(self, Constant(src) if isinstance(src, Number) else src) - - def __rtruediv__( - self, src: Union[SignalSourceProvider, Num] - ) -> Union["Division", "Reciprocal"]: - # Import here to avoid circular imports. - from b_asic.core_operations import Constant, Division, Reciprocal - - if isinstance(src, Number): - if src == 1: - return Reciprocal(self) - else: - return Division(Constant(src), self) - return Division(src, self) - def __lshift__(self, src: SignalSourceProvider) -> Signal: if self.input_count != 1: diff = "more" if self.input_count > 1 else "less" diff --git a/b_asic/port.py b/b_asic/port.py index 709a018b..0542f09e 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -6,12 +6,22 @@ Contains classes for managing the ports of operations. from abc import ABC, abstractmethod from copy import copy -from typing import TYPE_CHECKING, List, Optional, Sequence +from numbers import Number +from typing import TYPE_CHECKING, List, Optional, Sequence, Union from b_asic.graph_component import Name from b_asic.signal import Signal +from b_asic.types import Num if TYPE_CHECKING: + from b_asic.core_operations import ( + Addition, + ConstantMultiplication, + Division, + Multiplication, + Reciprocal, + Subtraction, + ) from b_asic.operation import Operation @@ -167,6 +177,76 @@ class SignalSourceProvider(ABC): """Get the main source port provided by this object.""" raise NotImplementedError + def __add__(self, src: Union["SignalSourceProvider", Num]) -> "Addition": + # Import here to avoid circular imports. + from b_asic.core_operations import Addition, Constant + + if isinstance(src, Number): + return Addition(self, Constant(src)) + else: + return Addition(self, src) + + def __radd__(self, src: Union["SignalSourceProvider", Num]) -> "Addition": + # Import here to avoid circular imports. + from b_asic.core_operations import Addition, Constant + + return Addition(Constant(src) if isinstance(src, Number) else src, self) + + def __sub__(self, src: Union["SignalSourceProvider", Num]) -> "Subtraction": + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Subtraction + + return Subtraction(self, Constant(src) if isinstance(src, Number) else src) + + def __rsub__(self, src: Union["SignalSourceProvider", Num]) -> "Subtraction": + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Subtraction + + return Subtraction(Constant(src) if isinstance(src, Number) else src, self) + + def __mul__( + self, src: Union["SignalSourceProvider", Num] + ) -> Union["Multiplication", "ConstantMultiplication"]: + # Import here to avoid circular imports. + from b_asic.core_operations import ConstantMultiplication, Multiplication + + return ( + ConstantMultiplication(src, self) + if isinstance(src, Number) + else Multiplication(self, src) + ) + + def __rmul__( + self, src: Union["SignalSourceProvider", Num] + ) -> Union["Multiplication", "ConstantMultiplication"]: + # Import here to avoid circular imports. + from b_asic.core_operations import ConstantMultiplication, Multiplication + + return ( + ConstantMultiplication(src, self) + if isinstance(src, Number) + else Multiplication(src, self) + ) + + def __truediv__(self, src: Union["SignalSourceProvider", Num]) -> "Division": + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Division + + return Division(self, Constant(src) if isinstance(src, Number) else src) + + def __rtruediv__( + self, src: Union["SignalSourceProvider", Num] + ) -> Union["Division", "Reciprocal"]: + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Division, Reciprocal + + if isinstance(src, Number): + if src == 1: + return Reciprocal(self) + else: + return Division(Constant(src), self) + return Division(src, self) + class InputPort(AbstractPort): """ diff --git a/test/test_operation.py b/test/test_operation.py index 6b16a148..69bfda46 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -40,6 +40,12 @@ class TestOperationOverloading: assert add5.input(0).signals[0].source.operation.value == 5 assert add5.input(1).signals == add4.output(0).signals + bfly = Butterfly() + add6 = bfly.output(0) + add5 + assert isinstance(add6, Addition) + assert add6.input(0).signals == bfly.output(0).signals + assert add6.input(1).signals == add5.output(0).signals + def test_subtraction_overload(self): """Tests subtraction overloading for both operation and number argument.""" add1 = Addition(None, None, "add1") @@ -60,6 +66,12 @@ class TestOperationOverloading: assert sub3.input(0).signals[0].source.operation.value == 5 assert sub3.input(1).signals == sub2.output(0).signals + bfly = Butterfly() + sub4 = bfly.output(0) - sub3 + assert isinstance(sub4, Subtraction) + assert sub4.input(0).signals == bfly.output(0).signals + assert sub4.input(1).signals == sub3.output(0).signals + def test_multiplication_overload(self): """Tests multiplication overloading for both operation and number argument.""" add1 = Addition(None, None, "add1") -- GitLab