diff --git a/b_asic/abstract_operation.py b/b_asic/abstract_operation.py index 1403f7a9bbff73c26f2667f99075159af33d3a0f..fc3a92051460a4d38dd843320c78a905119bf99e 100644 --- a/b_asic/abstract_operation.py +++ b/b_asic/abstract_operation.py @@ -32,7 +32,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): self._parameters = {} @abstractmethod - def evaluate(self, inputs: list) -> list: + def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ """Evaluate the operation and generate a list of output values given a list of input values.""" raise NotImplementedError diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 42867aa5a9d87546945e5a139fc1d04503fb67d9..f64c63db9da7c10e1dc95794e10e92233f0a9242 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -4,15 +4,14 @@ TODO: More info. """ from numbers import Number - +from typing import Any +from numpy import conjugate, sqrt, abs as np_abs from b_asic.port import InputPort, OutputPort -from b_asic.operation import Operation +from b_asic.graph_id import GraphIDType from b_asic.abstract_operation import AbstractOperation -from b_asic.abstract_graph_component import AbstractGraphComponent from b_asic.graph_component import Name, TypeName - -class Input(Operation, AbstractGraphComponent): +class Input(AbstractOperation): """Input operation. TODO: More info. """ @@ -24,6 +23,7 @@ class Input(Operation, AbstractGraphComponent): return "in" + class Constant(AbstractOperation): """Constant value operation. TODO: More info. @@ -32,15 +32,16 @@ class Constant(AbstractOperation): def __init__(self, value: Number = 0, name: Name = ""): super().__init__(name) - self._output_ports = [OutputPort(0, self)] # TODO: Generate appropriate ID for ports. + self._output_ports = [OutputPort(0, self)] self._parameters["value"] = value - def evaluate(self, inputs: list) -> list: - return [self.param("value")] + def evaluate(self) -> Any: + return self.param("value") @property def type_name(self) -> TypeName: - return "const" + return "c" + class Addition(AbstractOperation): @@ -51,22 +52,207 @@ class Addition(AbstractOperation): def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): super().__init__(name) - self._input_ports = [InputPort(0, self), InputPort(1, self)] # TODO: Generate appropriate ID for ports. - self._output_ports = [OutputPort(0, self)] # TODO: Generate appropriate ID for ports. + self._input_ports = [InputPort(0, self), InputPort(0, self)] + self._output_ports = [OutputPort(0, self)] if source1 is not None: self._input_ports[0].connect_to_port(source1) if source2 is not None: self._input_ports[1].connect_to_port(source2) - def evaluate(self, inputs: list) -> list: - return [inputs[0] + inputs[1]] + def evaluate(self, a, b) -> Any: + return a + b @property def type_name(self) -> TypeName: return "add" +class Subtraction(AbstractOperation): + """Binary subtraction operation. + TODO: More info. + """ + + def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): + super().__init__(name) + self._input_ports = [InputPort(0, self), InputPort(0, self)] + self._output_ports = [OutputPort(0, self)] + + if source1 is not None: + self._input_ports[0].connect_to_port(source1) + if source2 is not None: + self._input_ports[1].connect_to_port(source2) + + def evaluate(self, a, b) -> Any: + return a - b + + @property + def type_name(self) -> GraphIDType: + return "sub" + + +class Multiplication(AbstractOperation): + """Binary multiplication operation. + TODO: More info. + """ + + def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): + super().__init__(name) + self._input_ports = [InputPort(0, self), InputPort(0, self)] + self._output_ports = [OutputPort(0, self)] + + if source1 is not None: + self._input_ports[0].connect_to_port(source1) + if source2 is not None: + self._input_ports[1].connect_to_port(source2) + + def evaluate(self, a, b) -> Any: + return a * b + + @property + def type_name(self) -> GraphIDType: + return "mul" + + +class Division(AbstractOperation): + """Binary division operation. + TODO: More info. + """ + + def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): + super().__init__(name) + self._input_ports = [InputPort(0, self), InputPort(0, self)] + self._output_ports = [OutputPort(0, self)] + + if source1 is not None: + self._input_ports[0].connect_to_port(source1) + if source2 is not None: + self._input_ports[1].connect_to_port(source2) + + def evaluate(self, a, b) -> Any: + return a / b + + @property + def type_name(self) -> GraphIDType: + return "div" + + +class SquareRoot(AbstractOperation): + """Unary square root operation. + TODO: More info. + """ + + def __init__(self, source1: OutputPort = None, name: Name = ""): + super().__init__(name) + self._input_ports = [InputPort(0, self)] + self._output_ports = [OutputPort(0, self)] + + if source1 is not None: + self._input_ports[0].connect_to_port(source1) + + + def evaluate(self, a) -> Any: + return sqrt((complex)(a)) + + @property + def type_name(self) -> GraphIDType: + return "sqrt" + + +class ComplexConjugate(AbstractOperation): + """Unary complex conjugate operation. + TODO: More info. + """ + + def __init__(self, source1: OutputPort = None, name: Name = ""): + super().__init__(name) + self._input_ports = [InputPort(0, self)] + self._output_ports = [OutputPort(0, self)] + + if source1 is not None: + self._input_ports[0].connect_to_port(source1) + + + def evaluate(self, a) -> Any: + return conjugate(a) + + @property + def type_name(self) -> GraphIDType: + return "conj" + + +class Max(AbstractOperation): + """Binary max operation. + TODO: More info. + """ + + def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): + super().__init__(name) + self._input_ports = [InputPort(0, self), InputPort(0, self)] + self._output_ports = [OutputPort(0, self)] + + if source1 is not None: + self._input_ports[0].connect_to_port(source1) + if source2 is not None: + self._input_ports[1].connect_to_port(source2) + + def evaluate(self, a, b) -> Any: + assert not isinstance(a, complex) and not isinstance(b, complex), \ + ("core_operation.Max does not support complex numbers.") + return a if a > b else b + + @property + def type_name(self) -> GraphIDType: + return "max" + + +class Min(AbstractOperation): + """Binary min operation. + TODO: More info. + """ + + def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): + super().__init__(name) + self._input_ports = [InputPort(0, self), InputPort(0, self)] + self._output_ports = [OutputPort(0, self)] + + if source1 is not None: + self._input_ports[0].connect_to_port(source1) + if source2 is not None: + self._input_ports[1].connect_to_port(source2) + + def evaluate(self, a, b) -> Any: + assert not isinstance(a, complex) and not isinstance(b, complex), \ + ("core_operation.Min does not support complex numbers.") + return a if a < b else b + + @property + def type_name(self) -> GraphIDType: + return "min" + + +class Absolute(AbstractOperation): + """Unary absolute value operation. + TODO: More info. + """ + + def __init__(self, source1: OutputPort = None, name: Name = ""): + super().__init__(name) + self._input_ports = [InputPort(0, self)] + self._output_ports = [OutputPort(0, self)] + + if source1 is not None: + self._input_ports[0].connect_to_port(source1) + + + def evaluate(self, a) -> Any: + return np_abs(a) + + @property + def type_name(self) -> GraphIDType: + return "abs" + + class ConstantMultiplication(AbstractOperation): """Unary constant multiplication operation. TODO: More info. @@ -74,16 +260,82 @@ class ConstantMultiplication(AbstractOperation): def __init__(self, coefficient: Number, source1: OutputPort = None, name: Name = ""): super().__init__(name) - self._input_ports = [InputPort(0, self)] # TODO: Generate appropriate ID for ports. - self._output_ports = [OutputPort(0, self)] # TODO: Generate appropriate ID for ports. + self._input_ports = [InputPort(0, self)] + self._output_ports = [OutputPort(0, self)] self._parameters["coefficient"] = coefficient if source1 is not None: self._input_ports[0].connect_to_port(source1) - def evaluate(self, inputs: list) -> list: - return [inputs[0] * self.param("coefficient")] + def evaluate(self, a) -> Any: + return a * self.param("coefficient") @property def type_name(self) -> TypeName: - return "const_mul" + return "cmul" + + +class ConstantAddition(AbstractOperation): + """Unary constant addition operation. + TODO: More info. + """ + + def __init__(self, coefficient: Number, source1: OutputPort = None, name: Name = ""): + super().__init__(name) + self._input_ports = [InputPort(0, self)] + self._output_ports = [OutputPort(0, self)] + self._parameters["coefficient"] = coefficient + + if source1 is not None: + self._input_ports[0].connect_to_port(source1) + + def evaluate(self, a) -> Any: + return a + self.param("coefficient") + + @property + def type_name(self) -> GraphIDType: + return "cadd" + + +class ConstantSubtraction(AbstractOperation): + """Unary constant subtraction operation. + TODO: More info. + """ + + def __init__(self, coefficient: Number, source1: OutputPort = None, name: Name = ""): + super().__init__(name) + self._input_ports = [InputPort(0, self)] + self._output_ports = [OutputPort(0, self)] + self._parameters["coefficient"] = coefficient + + if source1 is not None: + self._input_ports[0].connect_to_port(source1) + + def evaluate(self, a) -> Any: + return a - self.param("coefficient") + + @property + def type_name(self) -> GraphIDType: + return "csub" + + +class ConstantDivision(AbstractOperation): + """Unary constant division operation. + TODO: More info. + """ + + def __init__(self, coefficient: Number, source1: OutputPort = None, name: Name = ""): + super().__init__(name) + self._input_ports = [InputPort(0, self)] + self._output_ports = [OutputPort(0, self)] + self._parameters["coefficient"] = coefficient + + if source1 is not None: + self._input_ports[0].connect_to_port(source1) + + def evaluate(self, a) -> Any: + return a / self.param("coefficient") + + @property + def type_name(self) -> GraphIDType: + return "cdiv" diff --git a/test/basic_operations/test_basic_operations.py b/test/basic_operations/test_basic_operations.py new file mode 100644 index 0000000000000000000000000000000000000000..215610749432e8104834fabd57e06e31f4d00fd0 --- /dev/null +++ b/test/basic_operations/test_basic_operations.py @@ -0,0 +1,229 @@ +""" +B-ASIC test suite for the basic operations. +""" + +from b_asic.core_operations import Constant, Addition, Subtraction, Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, ConstantDivision +from b_asic.signal import Signal +import pytest + +""" Constant tests. """ +def test_constant(): + constant_operation = Constant(3) + assert constant_operation.evaluate() == 3 + +def test_constant_negative(): + constant_operation = Constant(-3) + assert constant_operation.evaluate() == -3 + +def test_constant_complex(): + constant_operation = Constant(3+4j) + assert constant_operation.evaluate() == 3+4j + +""" Addition tests. """ +def test_addition(): + test_operation = Addition() + constant_operation = Constant(3) + constant_operation_2 = Constant(5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 8 + +def test_addition_negative(): + test_operation = Addition() + constant_operation = Constant(-3) + constant_operation_2 = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -8 + +def test_addition_complex(): + test_operation = Addition() + constant_operation = Constant((3+5j)) + constant_operation_2 = Constant((4+6j)) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j) + +""" Subtraction tests. """ +def test_subtraction(): + test_operation = Subtraction() + constant_operation = Constant(5) + constant_operation_2 = Constant(3) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 2 + +def test_subtraction_negative(): + test_operation = Subtraction() + constant_operation = Constant(-5) + constant_operation_2 = Constant(-3) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -2 + +def test_subtraction_complex(): + test_operation = Subtraction() + constant_operation = Constant((3+5j)) + constant_operation_2 = Constant((4+6j)) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j) + +""" Multiplication tests. """ +def test_multiplication(): + test_operation = Multiplication() + constant_operation = Constant(5) + constant_operation_2 = Constant(3) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 + +def test_multiplication_negative(): + test_operation = Multiplication() + constant_operation = Constant(-5) + constant_operation_2 = Constant(-3) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 + +def test_multiplication_complex(): + test_operation = Multiplication() + constant_operation = Constant((3+5j)) + constant_operation_2 = Constant((4+6j)) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j) + +""" Division tests. """ +def test_division(): + test_operation = Division() + constant_operation = Constant(30) + constant_operation_2 = Constant(5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 + +def test_division_negative(): + test_operation = Division() + constant_operation = Constant(-30) + constant_operation_2 = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 + +def test_division_complex(): + test_operation = Division() + constant_operation = Constant((60+40j)) + constant_operation_2 = Constant((10+20j)) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j) + +""" SquareRoot tests. """ +def test_squareroot(): + test_operation = SquareRoot() + constant_operation = Constant(36) + assert test_operation.evaluate(constant_operation.evaluate()) == 6 + +def test_squareroot_negative(): + test_operation = SquareRoot() + constant_operation = Constant(-36) + assert test_operation.evaluate(constant_operation.evaluate()) == 6j + +def test_squareroot_complex(): + test_operation = SquareRoot() + constant_operation = Constant((48+64j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (8+4j) + +""" ComplexConjugate tests. """ +def test_complexconjugate(): + test_operation = ComplexConjugate() + constant_operation = Constant(3+4j) + assert test_operation.evaluate(constant_operation.evaluate()) == (3-4j) + +def test_test_complexconjugate_negative(): + test_operation = ComplexConjugate() + constant_operation = Constant(-3-4j) + assert test_operation.evaluate(constant_operation.evaluate()) == (-3+4j) + +""" Max tests. """ +def test_max(): + test_operation = Max() + constant_operation = Constant(30) + constant_operation_2 = Constant(5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 30 + +def test_max_negative(): + test_operation = Max() + constant_operation = Constant(-30) + constant_operation_2 = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -5 + +""" Min tests. """ +def test_min(): + test_operation = Min() + constant_operation = Constant(30) + constant_operation_2 = Constant(5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 5 + +def test_min_negative(): + test_operation = Min() + constant_operation = Constant(-30) + constant_operation_2 = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -30 + +""" Absolute tests. """ +def test_absolute(): + test_operation = Absolute() + constant_operation = Constant(30) + assert test_operation.evaluate(constant_operation.evaluate()) == 30 + +def test_absolute_negative(): + test_operation = Absolute() + constant_operation = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate()) == 5 + +def test_absolute_complex(): + test_operation = Absolute() + constant_operation = Constant((3+4j)) + assert test_operation.evaluate(constant_operation.evaluate()) == 5.0 + +""" ConstantMultiplication tests. """ +def test_constantmultiplication(): + test_operation = ConstantMultiplication(5) + constant_operation = Constant(20) + assert test_operation.evaluate(constant_operation.evaluate()) == 100 + +def test_constantmultiplication_negative(): + test_operation = ConstantMultiplication(5) + constant_operation = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate()) == -25 + +def test_constantmultiplication_complex(): + test_operation = ConstantMultiplication(3+2j) + constant_operation = Constant((3+4j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (1+18j) + +""" ConstantAddition tests. """ +def test_constantaddition(): + test_operation = ConstantAddition(5) + constant_operation = Constant(20) + assert test_operation.evaluate(constant_operation.evaluate()) == 25 + +def test_constantaddition_negative(): + test_operation = ConstantAddition(4) + constant_operation = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate()) == -1 + +def test_constantaddition_complex(): + test_operation = ConstantAddition(3+2j) + constant_operation = Constant((3+4j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (6+6j) + +""" ConstantSubtraction tests. """ +def test_constantsubtraction(): + test_operation = ConstantSubtraction(5) + constant_operation = Constant(20) + assert test_operation.evaluate(constant_operation.evaluate()) == 15 + +def test_constantsubtraction_negative(): + test_operation = ConstantSubtraction(4) + constant_operation = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate()) == -9 + +def test_constantsubtraction_complex(): + test_operation = ConstantSubtraction(4+6j) + constant_operation = Constant((3+4j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (-1-2j) + +""" ConstantDivision tests. """ +def test_constantdivision(): + test_operation = ConstantDivision(5) + constant_operation = Constant(20) + assert test_operation.evaluate(constant_operation.evaluate()) == 4 + +def test_constantdivision_negative(): + test_operation = ConstantDivision(4) + constant_operation = Constant(-20) + assert test_operation.evaluate(constant_operation.evaluate()) == -5 + +def test_constantdivision_complex(): + test_operation = ConstantDivision(2+2j) + constant_operation = Constant((10+10j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j) \ No newline at end of file