Skip to content
Snippets Groups Projects
Commit 9494fe85 authored by Arvid Westerlund's avatar Arvid Westerlund
Browse files

Resolve "Basic Operations"

parent 24120414
No related branches found
No related tags found
3 merge requests!67WIP: B-ASIC version 1.0.0 hotfix,!65B-ASIC version 1.0.0,!15Add changes from sprint 1 and 2 to master
......@@ -32,7 +32,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
self._parameters = {}
@abstractmethod
def evaluate(self, inputs: list) -> list:
def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ
"""Evaluate the operation and generate a list of output values given a
list of input values."""
raise NotImplementedError
......
......@@ -4,15 +4,14 @@ TODO: More info.
"""
from numbers import Number
from typing import Any
from numpy import conjugate, sqrt, abs as np_abs
from b_asic.port import InputPort, OutputPort
from b_asic.operation import Operation
from b_asic.graph_id import GraphIDType
from b_asic.abstract_operation import AbstractOperation
from b_asic.abstract_graph_component import AbstractGraphComponent
from b_asic.graph_component import Name, TypeName
class Input(Operation, AbstractGraphComponent):
class Input(AbstractOperation):
"""Input operation.
TODO: More info.
"""
......@@ -24,6 +23,7 @@ class Input(Operation, AbstractGraphComponent):
return "in"
class Constant(AbstractOperation):
"""Constant value operation.
TODO: More info.
......@@ -32,15 +32,16 @@ class Constant(AbstractOperation):
def __init__(self, value: Number = 0, name: Name = ""):
super().__init__(name)
self._output_ports = [OutputPort(0, self)] # TODO: Generate appropriate ID for ports.
self._output_ports = [OutputPort(0, self)]
self._parameters["value"] = value
def evaluate(self, inputs: list) -> list:
return [self.param("value")]
def evaluate(self) -> Any:
return self.param("value")
@property
def type_name(self) -> TypeName:
return "const"
return "c"
class Addition(AbstractOperation):
......@@ -51,22 +52,207 @@ class Addition(AbstractOperation):
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self), InputPort(1, self)] # TODO: Generate appropriate ID for ports.
self._output_ports = [OutputPort(0, self)] # TODO: Generate appropriate ID for ports.
self._input_ports = [InputPort(0, self), InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect_to_port(source1)
if source2 is not None:
self._input_ports[1].connect_to_port(source2)
def evaluate(self, inputs: list) -> list:
return [inputs[0] + inputs[1]]
def evaluate(self, a, b) -> Any:
return a + b
@property
def type_name(self) -> TypeName:
return "add"
class Subtraction(AbstractOperation):
"""Binary subtraction operation.
TODO: More info.
"""
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self), InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect_to_port(source1)
if source2 is not None:
self._input_ports[1].connect_to_port(source2)
def evaluate(self, a, b) -> Any:
return a - b
@property
def type_name(self) -> GraphIDType:
return "sub"
class Multiplication(AbstractOperation):
"""Binary multiplication operation.
TODO: More info.
"""
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self), InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect_to_port(source1)
if source2 is not None:
self._input_ports[1].connect_to_port(source2)
def evaluate(self, a, b) -> Any:
return a * b
@property
def type_name(self) -> GraphIDType:
return "mul"
class Division(AbstractOperation):
"""Binary division operation.
TODO: More info.
"""
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self), InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect_to_port(source1)
if source2 is not None:
self._input_ports[1].connect_to_port(source2)
def evaluate(self, a, b) -> Any:
return a / b
@property
def type_name(self) -> GraphIDType:
return "div"
class SquareRoot(AbstractOperation):
"""Unary square root operation.
TODO: More info.
"""
def __init__(self, source1: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect_to_port(source1)
def evaluate(self, a) -> Any:
return sqrt((complex)(a))
@property
def type_name(self) -> GraphIDType:
return "sqrt"
class ComplexConjugate(AbstractOperation):
"""Unary complex conjugate operation.
TODO: More info.
"""
def __init__(self, source1: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect_to_port(source1)
def evaluate(self, a) -> Any:
return conjugate(a)
@property
def type_name(self) -> GraphIDType:
return "conj"
class Max(AbstractOperation):
"""Binary max operation.
TODO: More info.
"""
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self), InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect_to_port(source1)
if source2 is not None:
self._input_ports[1].connect_to_port(source2)
def evaluate(self, a, b) -> Any:
assert not isinstance(a, complex) and not isinstance(b, complex), \
("core_operation.Max does not support complex numbers.")
return a if a > b else b
@property
def type_name(self) -> GraphIDType:
return "max"
class Min(AbstractOperation):
"""Binary min operation.
TODO: More info.
"""
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self), InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect_to_port(source1)
if source2 is not None:
self._input_ports[1].connect_to_port(source2)
def evaluate(self, a, b) -> Any:
assert not isinstance(a, complex) and not isinstance(b, complex), \
("core_operation.Min does not support complex numbers.")
return a if a < b else b
@property
def type_name(self) -> GraphIDType:
return "min"
class Absolute(AbstractOperation):
"""Unary absolute value operation.
TODO: More info.
"""
def __init__(self, source1: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect_to_port(source1)
def evaluate(self, a) -> Any:
return np_abs(a)
@property
def type_name(self) -> GraphIDType:
return "abs"
class ConstantMultiplication(AbstractOperation):
"""Unary constant multiplication operation.
TODO: More info.
......@@ -74,16 +260,82 @@ class ConstantMultiplication(AbstractOperation):
def __init__(self, coefficient: Number, source1: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self)] # TODO: Generate appropriate ID for ports.
self._output_ports = [OutputPort(0, self)] # TODO: Generate appropriate ID for ports.
self._input_ports = [InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
self._parameters["coefficient"] = coefficient
if source1 is not None:
self._input_ports[0].connect_to_port(source1)
def evaluate(self, inputs: list) -> list:
return [inputs[0] * self.param("coefficient")]
def evaluate(self, a) -> Any:
return a * self.param("coefficient")
@property
def type_name(self) -> TypeName:
return "const_mul"
return "cmul"
class ConstantAddition(AbstractOperation):
"""Unary constant addition operation.
TODO: More info.
"""
def __init__(self, coefficient: Number, source1: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
self._parameters["coefficient"] = coefficient
if source1 is not None:
self._input_ports[0].connect_to_port(source1)
def evaluate(self, a) -> Any:
return a + self.param("coefficient")
@property
def type_name(self) -> GraphIDType:
return "cadd"
class ConstantSubtraction(AbstractOperation):
"""Unary constant subtraction operation.
TODO: More info.
"""
def __init__(self, coefficient: Number, source1: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
self._parameters["coefficient"] = coefficient
if source1 is not None:
self._input_ports[0].connect_to_port(source1)
def evaluate(self, a) -> Any:
return a - self.param("coefficient")
@property
def type_name(self) -> GraphIDType:
return "csub"
class ConstantDivision(AbstractOperation):
"""Unary constant division operation.
TODO: More info.
"""
def __init__(self, coefficient: Number, source1: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
self._parameters["coefficient"] = coefficient
if source1 is not None:
self._input_ports[0].connect_to_port(source1)
def evaluate(self, a) -> Any:
return a / self.param("coefficient")
@property
def type_name(self) -> GraphIDType:
return "cdiv"
"""
B-ASIC test suite for the basic operations.
"""
from b_asic.core_operations import Constant, Addition, Subtraction, Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, ConstantDivision
from b_asic.signal import Signal
import pytest
""" Constant tests. """
def test_constant():
constant_operation = Constant(3)
assert constant_operation.evaluate() == 3
def test_constant_negative():
constant_operation = Constant(-3)
assert constant_operation.evaluate() == -3
def test_constant_complex():
constant_operation = Constant(3+4j)
assert constant_operation.evaluate() == 3+4j
""" Addition tests. """
def test_addition():
test_operation = Addition()
constant_operation = Constant(3)
constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 8
def test_addition_negative():
test_operation = Addition()
constant_operation = Constant(-3)
constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -8
def test_addition_complex():
test_operation = Addition()
constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j)
""" Subtraction tests. """
def test_subtraction():
test_operation = Subtraction()
constant_operation = Constant(5)
constant_operation_2 = Constant(3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 2
def test_subtraction_negative():
test_operation = Subtraction()
constant_operation = Constant(-5)
constant_operation_2 = Constant(-3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -2
def test_subtraction_complex():
test_operation = Subtraction()
constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j)
""" Multiplication tests. """
def test_multiplication():
test_operation = Multiplication()
constant_operation = Constant(5)
constant_operation_2 = Constant(3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15
def test_multiplication_negative():
test_operation = Multiplication()
constant_operation = Constant(-5)
constant_operation_2 = Constant(-3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15
def test_multiplication_complex():
test_operation = Multiplication()
constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j)
""" Division tests. """
def test_division():
test_operation = Division()
constant_operation = Constant(30)
constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6
def test_division_negative():
test_operation = Division()
constant_operation = Constant(-30)
constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6
def test_division_complex():
test_operation = Division()
constant_operation = Constant((60+40j))
constant_operation_2 = Constant((10+20j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j)
""" SquareRoot tests. """
def test_squareroot():
test_operation = SquareRoot()
constant_operation = Constant(36)
assert test_operation.evaluate(constant_operation.evaluate()) == 6
def test_squareroot_negative():
test_operation = SquareRoot()
constant_operation = Constant(-36)
assert test_operation.evaluate(constant_operation.evaluate()) == 6j
def test_squareroot_complex():
test_operation = SquareRoot()
constant_operation = Constant((48+64j))
assert test_operation.evaluate(constant_operation.evaluate()) == (8+4j)
""" ComplexConjugate tests. """
def test_complexconjugate():
test_operation = ComplexConjugate()
constant_operation = Constant(3+4j)
assert test_operation.evaluate(constant_operation.evaluate()) == (3-4j)
def test_test_complexconjugate_negative():
test_operation = ComplexConjugate()
constant_operation = Constant(-3-4j)
assert test_operation.evaluate(constant_operation.evaluate()) == (-3+4j)
""" Max tests. """
def test_max():
test_operation = Max()
constant_operation = Constant(30)
constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 30
def test_max_negative():
test_operation = Max()
constant_operation = Constant(-30)
constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -5
""" Min tests. """
def test_min():
test_operation = Min()
constant_operation = Constant(30)
constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 5
def test_min_negative():
test_operation = Min()
constant_operation = Constant(-30)
constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -30
""" Absolute tests. """
def test_absolute():
test_operation = Absolute()
constant_operation = Constant(30)
assert test_operation.evaluate(constant_operation.evaluate()) == 30
def test_absolute_negative():
test_operation = Absolute()
constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == 5
def test_absolute_complex():
test_operation = Absolute()
constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == 5.0
""" ConstantMultiplication tests. """
def test_constantmultiplication():
test_operation = ConstantMultiplication(5)
constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 100
def test_constantmultiplication_negative():
test_operation = ConstantMultiplication(5)
constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -25
def test_constantmultiplication_complex():
test_operation = ConstantMultiplication(3+2j)
constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (1+18j)
""" ConstantAddition tests. """
def test_constantaddition():
test_operation = ConstantAddition(5)
constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 25
def test_constantaddition_negative():
test_operation = ConstantAddition(4)
constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -1
def test_constantaddition_complex():
test_operation = ConstantAddition(3+2j)
constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (6+6j)
""" ConstantSubtraction tests. """
def test_constantsubtraction():
test_operation = ConstantSubtraction(5)
constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 15
def test_constantsubtraction_negative():
test_operation = ConstantSubtraction(4)
constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -9
def test_constantsubtraction_complex():
test_operation = ConstantSubtraction(4+6j)
constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (-1-2j)
""" ConstantDivision tests. """
def test_constantdivision():
test_operation = ConstantDivision(5)
constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 4
def test_constantdivision_negative():
test_operation = ConstantDivision(4)
constant_operation = Constant(-20)
assert test_operation.evaluate(constant_operation.evaluate()) == -5
def test_constantdivision_complex():
test_operation = ConstantDivision(2+2j)
constant_operation = Constant((10+10j))
assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment