From d8d426459f7d26d14db117c38ab6220ea776850e Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Fri, 19 May 2023 21:42:22 +0200
Subject: [PATCH] Move operator overloading to SignalSourceProvider

---
 b_asic/operation.py    | 155 -----------------------------------------
 b_asic/port.py         |  82 +++++++++++++++++++++-
 test/test_operation.py |  12 ++++
 3 files changed, 93 insertions(+), 156 deletions(-)

diff --git a/b_asic/operation.py b/b_asic/operation.py
index 2d0c8972..4e6edd5f 100644
--- a/b_asic/operation.py
+++ b/b_asic/operation.py
@@ -31,15 +31,6 @@ from b_asic.signal import Signal
 from b_asic.types import Num
 
 if TYPE_CHECKING:
-    # Conditionally imported to avoid circular imports
-    from b_asic.core_operations import (
-        Addition,
-        ConstantMultiplication,
-        Division,
-        Multiplication,
-        Reciprocal,
-        Subtraction,
-    )
     from b_asic.signal_flow_graph import SFG
 
 
@@ -62,82 +53,6 @@ class Operation(GraphComponent, SignalSourceProvider):
     Operations may specify how to quantize inputs through quantize_input().
     """
 
-    @abstractmethod
-    def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
-        """
-        Overload the addition operator to make it return a new Addition operation
-        object that is connected to the self and other objects.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
-        """
-        Overload the addition operator to make it return a new Addition operation
-        object that is connected to the self and other objects.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
-        """
-        Overload the subtraction operator to make it return a new Subtraction
-        operation object that is connected to the self and other objects.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
-        """
-        Overloads the subtraction operator to make it return a new Subtraction
-        operation object that is connected to the self and other objects.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def __mul__(
-        self, src: Union[SignalSourceProvider, Num]
-    ) -> Union["Multiplication", "ConstantMultiplication"]:
-        """
-        Overload the multiplication operator to make it return a new Multiplication
-        operation object that is connected to the self and other objects.
-
-        If *src* is a number, then returns a ConstantMultiplication operation object
-        instead.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def __rmul__(
-        self, src: Union[SignalSourceProvider, Num]
-    ) -> Union["Multiplication", "ConstantMultiplication"]:
-        """
-        Overload the multiplication operator to make it return a new Multiplication
-        operation object that is connected to the self and other objects.
-
-        If *src* is a number, then returns a ConstantMultiplication operation object
-        instead.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division":
-        """
-        Overload the division operator to make it return a new Division operation
-        object that is connected to the self and other objects.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def __rtruediv__(
-        self, src: Union[SignalSourceProvider, Num]
-    ) -> Union["Division", "Reciprocal"]:
-        """
-        Overload the division operator to make it return a new Division operation
-        object that is connected to the self and other objects.
-        """
-        raise NotImplementedError
-
     @abstractmethod
     def __lshift__(self, src: SignalSourceProvider) -> Signal:
         """
@@ -626,76 +541,6 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         """
         raise NotImplementedError
 
-    def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
-        # Import here to avoid circular imports.
-        from b_asic.core_operations import Addition, Constant
-
-        if isinstance(src, Number):
-            return Addition(self, Constant(src))
-        else:
-            return Addition(self, src)
-
-    def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
-        # Import here to avoid circular imports.
-        from b_asic.core_operations import Addition, Constant
-
-        return Addition(Constant(src) if isinstance(src, Number) else src, self)
-
-    def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
-        # Import here to avoid circular imports.
-        from b_asic.core_operations import Constant, Subtraction
-
-        return Subtraction(self, Constant(src) if isinstance(src, Number) else src)
-
-    def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
-        # Import here to avoid circular imports.
-        from b_asic.core_operations import Constant, Subtraction
-
-        return Subtraction(Constant(src) if isinstance(src, Number) else src, self)
-
-    def __mul__(
-        self, src: Union[SignalSourceProvider, Num]
-    ) -> Union["Multiplication", "ConstantMultiplication"]:
-        # Import here to avoid circular imports.
-        from b_asic.core_operations import ConstantMultiplication, Multiplication
-
-        return (
-            ConstantMultiplication(src, self)
-            if isinstance(src, Number)
-            else Multiplication(self, src)
-        )
-
-    def __rmul__(
-        self, src: Union[SignalSourceProvider, Num]
-    ) -> Union["Multiplication", "ConstantMultiplication"]:
-        # Import here to avoid circular imports.
-        from b_asic.core_operations import ConstantMultiplication, Multiplication
-
-        return (
-            ConstantMultiplication(src, self)
-            if isinstance(src, Number)
-            else Multiplication(src, self)
-        )
-
-    def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division":
-        # Import here to avoid circular imports.
-        from b_asic.core_operations import Constant, Division
-
-        return Division(self, Constant(src) if isinstance(src, Number) else src)
-
-    def __rtruediv__(
-        self, src: Union[SignalSourceProvider, Num]
-    ) -> Union["Division", "Reciprocal"]:
-        # Import here to avoid circular imports.
-        from b_asic.core_operations import Constant, Division, Reciprocal
-
-        if isinstance(src, Number):
-            if src == 1:
-                return Reciprocal(self)
-            else:
-                return Division(Constant(src), self)
-        return Division(src, self)
-
     def __lshift__(self, src: SignalSourceProvider) -> Signal:
         if self.input_count != 1:
             diff = "more" if self.input_count > 1 else "less"
diff --git a/b_asic/port.py b/b_asic/port.py
index 709a018b..0542f09e 100644
--- a/b_asic/port.py
+++ b/b_asic/port.py
@@ -6,12 +6,22 @@ Contains classes for managing the ports of operations.
 
 from abc import ABC, abstractmethod
 from copy import copy
-from typing import TYPE_CHECKING, List, Optional, Sequence
+from numbers import Number
+from typing import TYPE_CHECKING, List, Optional, Sequence, Union
 
 from b_asic.graph_component import Name
 from b_asic.signal import Signal
+from b_asic.types import Num
 
 if TYPE_CHECKING:
+    from b_asic.core_operations import (
+        Addition,
+        ConstantMultiplication,
+        Division,
+        Multiplication,
+        Reciprocal,
+        Subtraction,
+    )
     from b_asic.operation import Operation
 
 
@@ -167,6 +177,76 @@ class SignalSourceProvider(ABC):
         """Get the main source port provided by this object."""
         raise NotImplementedError
 
+    def __add__(self, src: Union["SignalSourceProvider", Num]) -> "Addition":
+        # Import here to avoid circular imports.
+        from b_asic.core_operations import Addition, Constant
+
+        if isinstance(src, Number):
+            return Addition(self, Constant(src))
+        else:
+            return Addition(self, src)
+
+    def __radd__(self, src: Union["SignalSourceProvider", Num]) -> "Addition":
+        # Import here to avoid circular imports.
+        from b_asic.core_operations import Addition, Constant
+
+        return Addition(Constant(src) if isinstance(src, Number) else src, self)
+
+    def __sub__(self, src: Union["SignalSourceProvider", Num]) -> "Subtraction":
+        # Import here to avoid circular imports.
+        from b_asic.core_operations import Constant, Subtraction
+
+        return Subtraction(self, Constant(src) if isinstance(src, Number) else src)
+
+    def __rsub__(self, src: Union["SignalSourceProvider", Num]) -> "Subtraction":
+        # Import here to avoid circular imports.
+        from b_asic.core_operations import Constant, Subtraction
+
+        return Subtraction(Constant(src) if isinstance(src, Number) else src, self)
+
+    def __mul__(
+        self, src: Union["SignalSourceProvider", Num]
+    ) -> Union["Multiplication", "ConstantMultiplication"]:
+        # Import here to avoid circular imports.
+        from b_asic.core_operations import ConstantMultiplication, Multiplication
+
+        return (
+            ConstantMultiplication(src, self)
+            if isinstance(src, Number)
+            else Multiplication(self, src)
+        )
+
+    def __rmul__(
+        self, src: Union["SignalSourceProvider", Num]
+    ) -> Union["Multiplication", "ConstantMultiplication"]:
+        # Import here to avoid circular imports.
+        from b_asic.core_operations import ConstantMultiplication, Multiplication
+
+        return (
+            ConstantMultiplication(src, self)
+            if isinstance(src, Number)
+            else Multiplication(src, self)
+        )
+
+    def __truediv__(self, src: Union["SignalSourceProvider", Num]) -> "Division":
+        # Import here to avoid circular imports.
+        from b_asic.core_operations import Constant, Division
+
+        return Division(self, Constant(src) if isinstance(src, Number) else src)
+
+    def __rtruediv__(
+        self, src: Union["SignalSourceProvider", Num]
+    ) -> Union["Division", "Reciprocal"]:
+        # Import here to avoid circular imports.
+        from b_asic.core_operations import Constant, Division, Reciprocal
+
+        if isinstance(src, Number):
+            if src == 1:
+                return Reciprocal(self)
+            else:
+                return Division(Constant(src), self)
+        return Division(src, self)
+
 
 class InputPort(AbstractPort):
     """
diff --git a/test/test_operation.py b/test/test_operation.py
index 6b16a148..69bfda46 100644
--- a/test/test_operation.py
+++ b/test/test_operation.py
@@ -40,6 +40,12 @@ class TestOperationOverloading:
         assert add5.input(0).signals[0].source.operation.value == 5
         assert add5.input(1).signals == add4.output(0).signals
 
+        bfly = Butterfly()
+        add6 = bfly.output(0) + add5
+        assert isinstance(add6, Addition)
+        assert add6.input(0).signals == bfly.output(0).signals
+        assert add6.input(1).signals == add5.output(0).signals
+
     def test_subtraction_overload(self):
         """Tests subtraction overloading for both operation and number argument."""
         add1 = Addition(None, None, "add1")
@@ -60,6 +66,12 @@ class TestOperationOverloading:
         assert sub3.input(0).signals[0].source.operation.value == 5
         assert sub3.input(1).signals == sub2.output(0).signals
 
+        bfly = Butterfly()
+        sub4 = bfly.output(0) - sub3
+        assert isinstance(sub4, Subtraction)
+        assert sub4.input(0).signals == bfly.output(0).signals
+        assert sub4.input(1).signals == sub3.output(0).signals
+
     def test_multiplication_overload(self):
         """Tests multiplication overloading for both operation and number argument."""
         add1 = Addition(None, None, "add1")
-- 
GitLab