diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index d06cbab3dc5356465bdc3c2f2b93c55f26715aed..8902b169c07e600a843e526d44452fb9386ff7a9 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -55,9 +55,9 @@ class Addition(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) if source2 is not None: - self._input_ports[1].connect_to_port(source2) + self._input_ports[1].connect(source2) def evaluate(self, a, b): return a + b @@ -78,9 +78,9 @@ class Subtraction(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) if source2 is not None: - self._input_ports[1].connect_to_port(source2) + self._input_ports[1].connect(source2) def evaluate(self, a, b): return a - b @@ -101,9 +101,9 @@ class Multiplication(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) if source2 is not None: - self._input_ports[1].connect_to_port(source2) + self._input_ports[1].connect(source2) def evaluate(self, a, b): return a * b @@ -124,9 +124,9 @@ class Division(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) if source2 is not None: - self._input_ports[1].connect_to_port(source2) + self._input_ports[1].connect(source2) def evaluate(self, a, b): return a / b @@ -147,8 +147,7 @@ class SquareRoot(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) - + self._input_ports[0].connect(source1) def evaluate(self, a): return sqrt((complex)(a)) @@ -169,8 +168,7 @@ class ComplexConjugate(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) - + self._input_ports[0].connect(source1) def evaluate(self, a): return conjugate(a) @@ -191,9 +189,9 @@ class Max(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) if source2 is not None: - self._input_ports[1].connect_to_port(source2) + self._input_ports[1].connect(source2) def evaluate(self, a, b): assert not isinstance(a, complex) and not isinstance(b, complex), \ @@ -216,9 +214,9 @@ class Min(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) if source2 is not None: - self._input_ports[1].connect_to_port(source2) + self._input_ports[1].connect(source2) def evaluate(self, a, b): assert not isinstance(a, complex) and not isinstance(b, complex), \ @@ -241,8 +239,7 @@ class Absolute(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) - + self._input_ports[0].connect(source1) def evaluate(self, a): return np_abs(a) @@ -264,7 +261,7 @@ class ConstantMultiplication(AbstractOperation): self._parameters["coefficient"] = coefficient if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) def evaluate(self, a): return a * self.param("coefficient") @@ -286,7 +283,7 @@ class ConstantAddition(AbstractOperation): self._parameters["coefficient"] = coefficient if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) def evaluate(self, a): return a + self.param("coefficient") @@ -308,7 +305,7 @@ class ConstantSubtraction(AbstractOperation): self._parameters["coefficient"] = coefficient if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) def evaluate(self, a): return a - self.param("coefficient") @@ -330,7 +327,7 @@ class ConstantDivision(AbstractOperation): self._parameters["coefficient"] = coefficient if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) def evaluate(self, a): return a / self.param("coefficient") diff --git a/b_asic/operation.py b/b_asic/operation.py index fc007ffde91e596fc103e811677d7aab96241fa4..5578e3c48edcf15594d6d1cd71e71a17521eca25 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -109,9 +109,10 @@ class AbstractOperation(Operation, AbstractGraphComponent): self._parameters = {} @abstractmethod - def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ + def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ """Evaluate the operation and generate a list of output values given a - list of input values.""" + list of input values. + """ raise NotImplementedError def inputs(self) -> List["InputPort"]: @@ -139,15 +140,15 @@ class AbstractOperation(Operation, AbstractGraphComponent): return self._parameters.get(name) def set_param(self, name: str, value: Any) -> None: - assert name in self._parameters # TODO: Error message. + assert name in self._parameters # TODO: Error message. self._parameters[name] = value def evaluate_outputs(self, state: SimulationState) -> List[Number]: # TODO: Check implementation. input_count: int = self.input_count() output_count: int = self.output_count() - assert input_count == len(self._input_ports) # TODO: Error message. - assert output_count == len(self._output_ports) # TODO: Error message. + assert input_count == len(self._input_ports) # TODO: Error message. + assert output_count == len(self._output_ports) # TODO: Error message. self_state: OperationState = state.operation_states[self] @@ -155,10 +156,12 @@ class AbstractOperation(Operation, AbstractGraphComponent): input_values: List[Number] = [0] * input_count for i in range(input_count): source: Signal = self._input_ports[i].signal - input_values[i] = source.operation.evaluate_outputs(state)[source.port_index] + input_values[i] = source.operation.evaluate_outputs(state)[ + source.port_index] self_state.output_values = self.evaluate(input_values) - assert len(self_state.output_values) == output_count # TODO: Error message. + # TODO: Error message. + assert len(self_state.output_values) == output_count self_state.iteration += 1 for i in range(output_count): for signal in self._output_ports[i].signals(): @@ -202,3 +205,64 @@ class AbstractOperation(Operation, AbstractGraphComponent): if n_operation not in visited: visited.add(n_operation) queue.append(n_operation) + + def __add__(self, other): + """Overloads the addition operator to make it return a new Addition operation + object that is connected to the self and other objects. If other is a number then + returns a ConstantAddition operation object instead. + """ + # Import here to avoid circular imports. + from b_asic.core_operations import Addition, ConstantAddition + + if isinstance(other, Operation): + return Addition(self.output(0), other.output(0)) + elif isinstance(other, Number): + return ConstantAddition(other, self.output(0)) + else: + raise TypeError("Other type is not an Operation or a Number.") + + def __sub__(self, other): + """Overloads the subtraction operator to make it return a new Subtraction operation + object that is connected to the self and other objects. If other is a number then + returns a ConstantSubtraction operation object instead. + """ + # Import here to avoid circular imports. + from b_asic.core_operations import Subtraction, ConstantSubtraction + + if isinstance(other, Operation): + return Subtraction(self.output(0), other.output(0)) + elif isinstance(other, Number): + return ConstantSubtraction(other, self.output(0)) + else: + raise TypeError("Other type is not an Operation or a Number.") + + def __mul__(self, other): + """Overloads the multiplication operator to make it return a new Multiplication operation + object that is connected to the self and other objects. If other is a number then + returns a ConstantMultiplication operation object instead. + """ + # Import here to avoid circular imports. + from b_asic.core_operations import Multiplication, ConstantMultiplication + + if isinstance(other, Operation): + return Multiplication(self.output(0), other.output(0)) + elif isinstance(other, Number): + return ConstantMultiplication(other, self.output(0)) + else: + raise TypeError("Other type is not an Operation or a Number.") + + def __truediv__(self, other): + """Overloads the division operator to make it return a new Division operation + object that is connected to the self and other objects. If other is a number then + returns a ConstantDivision operation object instead. + """ + # Import here to avoid circular imports. + from b_asic.core_operations import Division, ConstantDivision + + if isinstance(other, Operation): + return Division(self.output(0), other.output(0)) + elif isinstance(other, Number): + return ConstantDivision(other, self.output(0)) + else: + raise TypeError("Other type is not an Operation or a Number.") + diff --git a/test/operation/test_abstract_operation.py b/test/operation/test_abstract_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..626a2dc3e5e26fb76d9266dcdd31940681df5c6e --- /dev/null +++ b/test/operation/test_abstract_operation.py @@ -0,0 +1,77 @@ +""" +B-ASIC test suite for the AbstractOperation class. +""" + +from b_asic.core_operations import Addition, ConstantAddition, Subtraction, ConstantSubtraction, \ + Multiplication, ConstantMultiplication, Division, ConstantDivision + +import pytest + + +def test_addition_overload(): + """Tests addition overloading for both operation and number argument.""" + add1 = Addition(None, None, "add1") + add2 = Addition(None, None, "add2") + + add3 = add1 + add2 + + assert isinstance(add3, Addition) + assert add3.input(0).signals == add1.output(0).signals + assert add3.input(1).signals == add2.output(0).signals + + add4 = add3 + 5 + + assert isinstance(add4, ConstantAddition) + assert add4.input(0).signals == add3.output(0).signals + + +def test_subtraction_overload(): + """Tests subtraction overloading for both operation and number argument.""" + add1 = Addition(None, None, "add1") + add2 = Addition(None, None, "add2") + + sub1 = add1 - add2 + + assert isinstance(sub1, Subtraction) + assert sub1.input(0).signals == add1.output(0).signals + assert sub1.input(1).signals == add2.output(0).signals + + sub2 = sub1 - 5 + + assert isinstance(sub2, ConstantSubtraction) + assert sub2.input(0).signals == sub1.output(0).signals + + +def test_multiplication_overload(): + """Tests multiplication overloading for both operation and number argument.""" + add1 = Addition(None, None, "add1") + add2 = Addition(None, None, "add2") + + mul1 = add1 * add2 + + assert isinstance(mul1, Multiplication) + assert mul1.input(0).signals == add1.output(0).signals + assert mul1.input(1).signals == add2.output(0).signals + + mul2 = mul1 * 5 + + assert isinstance(mul2, ConstantMultiplication) + assert mul2.input(0).signals == mul1.output(0).signals + + +def test_division_overload(): + """Tests division overloading for both operation and number argument.""" + add1 = Addition(None, None, "add1") + add2 = Addition(None, None, "add2") + + div1 = add1 / add2 + + assert isinstance(div1, Division) + assert div1.input(0).signals == add1.output(0).signals + assert div1.input(1).signals == add2.output(0).signals + + div2 = div1 / 5 + + assert isinstance(div2, ConstantDivision) + assert div2.input(0).signals == div1.output(0).signals +