diff --git a/b_asic/operation.py b/b_asic/operation.py index 2d0c8972a8b33c43aaf209190c51f70766e2bf38..4e6edd5f4207e06d57697f93719bec6884f32e8c 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -31,15 +31,6 @@ from b_asic.signal import Signal from b_asic.types import Num if TYPE_CHECKING: - # Conditionally imported to avoid circular imports - from b_asic.core_operations import ( - Addition, - ConstantMultiplication, - Division, - Multiplication, - Reciprocal, - Subtraction, - ) from b_asic.signal_flow_graph import SFG @@ -62,82 +53,6 @@ class Operation(GraphComponent, SignalSourceProvider): Operations may specify how to quantize inputs through quantize_input(). """ - @abstractmethod - def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition": - """ - Overload the addition operator to make it return a new Addition operation - object that is connected to the self and other objects. - """ - raise NotImplementedError - - @abstractmethod - def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition": - """ - Overload the addition operator to make it return a new Addition operation - object that is connected to the self and other objects. - """ - raise NotImplementedError - - @abstractmethod - def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": - """ - Overload the subtraction operator to make it return a new Subtraction - operation object that is connected to the self and other objects. - """ - raise NotImplementedError - - @abstractmethod - def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": - """ - Overloads the subtraction operator to make it return a new Subtraction - operation object that is connected to the self and other objects. - """ - raise NotImplementedError - - @abstractmethod - def __mul__( - self, src: Union[SignalSourceProvider, Num] - ) -> Union["Multiplication", "ConstantMultiplication"]: - """ - Overload the multiplication operator to make it return a new Multiplication - operation object that is connected to the self and other objects. - - If *src* is a number, then returns a ConstantMultiplication operation object - instead. - """ - raise NotImplementedError - - @abstractmethod - def __rmul__( - self, src: Union[SignalSourceProvider, Num] - ) -> Union["Multiplication", "ConstantMultiplication"]: - """ - Overload the multiplication operator to make it return a new Multiplication - operation object that is connected to the self and other objects. - - If *src* is a number, then returns a ConstantMultiplication operation object - instead. - """ - raise NotImplementedError - - @abstractmethod - def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division": - """ - Overload the division operator to make it return a new Division operation - object that is connected to the self and other objects. - """ - raise NotImplementedError - - @abstractmethod - def __rtruediv__( - self, src: Union[SignalSourceProvider, Num] - ) -> Union["Division", "Reciprocal"]: - """ - Overload the division operator to make it return a new Division operation - object that is connected to the self and other objects. - """ - raise NotImplementedError - @abstractmethod def __lshift__(self, src: SignalSourceProvider) -> Signal: """ @@ -626,76 +541,6 @@ class AbstractOperation(Operation, AbstractGraphComponent): """ raise NotImplementedError - def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition": - # Import here to avoid circular imports. - from b_asic.core_operations import Addition, Constant - - if isinstance(src, Number): - return Addition(self, Constant(src)) - else: - return Addition(self, src) - - def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition": - # Import here to avoid circular imports. - from b_asic.core_operations import Addition, Constant - - return Addition(Constant(src) if isinstance(src, Number) else src, self) - - def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": - # Import here to avoid circular imports. - from b_asic.core_operations import Constant, Subtraction - - return Subtraction(self, Constant(src) if isinstance(src, Number) else src) - - def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": - # Import here to avoid circular imports. - from b_asic.core_operations import Constant, Subtraction - - return Subtraction(Constant(src) if isinstance(src, Number) else src, self) - - def __mul__( - self, src: Union[SignalSourceProvider, Num] - ) -> Union["Multiplication", "ConstantMultiplication"]: - # Import here to avoid circular imports. - from b_asic.core_operations import ConstantMultiplication, Multiplication - - return ( - ConstantMultiplication(src, self) - if isinstance(src, Number) - else Multiplication(self, src) - ) - - def __rmul__( - self, src: Union[SignalSourceProvider, Num] - ) -> Union["Multiplication", "ConstantMultiplication"]: - # Import here to avoid circular imports. - from b_asic.core_operations import ConstantMultiplication, Multiplication - - return ( - ConstantMultiplication(src, self) - if isinstance(src, Number) - else Multiplication(src, self) - ) - - def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division": - # Import here to avoid circular imports. - from b_asic.core_operations import Constant, Division - - return Division(self, Constant(src) if isinstance(src, Number) else src) - - def __rtruediv__( - self, src: Union[SignalSourceProvider, Num] - ) -> Union["Division", "Reciprocal"]: - # Import here to avoid circular imports. - from b_asic.core_operations import Constant, Division, Reciprocal - - if isinstance(src, Number): - if src == 1: - return Reciprocal(self) - else: - return Division(Constant(src), self) - return Division(src, self) - def __lshift__(self, src: SignalSourceProvider) -> Signal: if self.input_count != 1: diff = "more" if self.input_count > 1 else "less" diff --git a/b_asic/port.py b/b_asic/port.py index 709a018ba542e37c265942e137a296a6bdf2c003..0542f09ef101a56f62f40dd8a5ae92a3e209018d 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -6,12 +6,22 @@ Contains classes for managing the ports of operations. from abc import ABC, abstractmethod from copy import copy -from typing import TYPE_CHECKING, List, Optional, Sequence +from numbers import Number +from typing import TYPE_CHECKING, List, Optional, Sequence, Union from b_asic.graph_component import Name from b_asic.signal import Signal +from b_asic.types import Num if TYPE_CHECKING: + from b_asic.core_operations import ( + Addition, + ConstantMultiplication, + Division, + Multiplication, + Reciprocal, + Subtraction, + ) from b_asic.operation import Operation @@ -167,6 +177,76 @@ class SignalSourceProvider(ABC): """Get the main source port provided by this object.""" raise NotImplementedError + def __add__(self, src: Union["SignalSourceProvider", Num]) -> "Addition": + # Import here to avoid circular imports. + from b_asic.core_operations import Addition, Constant + + if isinstance(src, Number): + return Addition(self, Constant(src)) + else: + return Addition(self, src) + + def __radd__(self, src: Union["SignalSourceProvider", Num]) -> "Addition": + # Import here to avoid circular imports. + from b_asic.core_operations import Addition, Constant + + return Addition(Constant(src) if isinstance(src, Number) else src, self) + + def __sub__(self, src: Union["SignalSourceProvider", Num]) -> "Subtraction": + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Subtraction + + return Subtraction(self, Constant(src) if isinstance(src, Number) else src) + + def __rsub__(self, src: Union["SignalSourceProvider", Num]) -> "Subtraction": + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Subtraction + + return Subtraction(Constant(src) if isinstance(src, Number) else src, self) + + def __mul__( + self, src: Union["SignalSourceProvider", Num] + ) -> Union["Multiplication", "ConstantMultiplication"]: + # Import here to avoid circular imports. + from b_asic.core_operations import ConstantMultiplication, Multiplication + + return ( + ConstantMultiplication(src, self) + if isinstance(src, Number) + else Multiplication(self, src) + ) + + def __rmul__( + self, src: Union["SignalSourceProvider", Num] + ) -> Union["Multiplication", "ConstantMultiplication"]: + # Import here to avoid circular imports. + from b_asic.core_operations import ConstantMultiplication, Multiplication + + return ( + ConstantMultiplication(src, self) + if isinstance(src, Number) + else Multiplication(src, self) + ) + + def __truediv__(self, src: Union["SignalSourceProvider", Num]) -> "Division": + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Division + + return Division(self, Constant(src) if isinstance(src, Number) else src) + + def __rtruediv__( + self, src: Union["SignalSourceProvider", Num] + ) -> Union["Division", "Reciprocal"]: + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Division, Reciprocal + + if isinstance(src, Number): + if src == 1: + return Reciprocal(self) + else: + return Division(Constant(src), self) + return Division(src, self) + class InputPort(AbstractPort): """ diff --git a/test/test_operation.py b/test/test_operation.py index 6b16a148405a3ffc32401b3e95c3184937926661..69bfda46bf4187e40d278ab91dc84cbfee050b94 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -40,6 +40,12 @@ class TestOperationOverloading: assert add5.input(0).signals[0].source.operation.value == 5 assert add5.input(1).signals == add4.output(0).signals + bfly = Butterfly() + add6 = bfly.output(0) + add5 + assert isinstance(add6, Addition) + assert add6.input(0).signals == bfly.output(0).signals + assert add6.input(1).signals == add5.output(0).signals + def test_subtraction_overload(self): """Tests subtraction overloading for both operation and number argument.""" add1 = Addition(None, None, "add1") @@ -60,6 +66,12 @@ class TestOperationOverloading: assert sub3.input(0).signals[0].source.operation.value == 5 assert sub3.input(1).signals == sub2.output(0).signals + bfly = Butterfly() + sub4 = bfly.output(0) - sub3 + assert isinstance(sub4, Subtraction) + assert sub4.input(0).signals == bfly.output(0).signals + assert sub4.input(1).signals == sub3.output(0).signals + def test_multiplication_overload(self): """Tests multiplication overloading for both operation and number argument.""" add1 = Addition(None, None, "add1")