Skip to content
Snippets Groups Projects
Commit 37d0425a authored by Angus Lothian's avatar Angus Lothian :dark_sunglasses: Committed by Ivar Härnqvist
Browse files

Change test of multiple outputs of evaluate output and Butterfly to not depend...

Change test of multiple outputs of evaluate output and Butterfly to not depend on implementation returing list or tuple
parent 7e2d5182
No related branches found
No related tags found
2 merge requests!67WIP: B-ASIC version 1.0.0 hotfix,!65B-ASIC version 1.0.0
......@@ -4,10 +4,8 @@ TODO: More info.
"""
from numbers import Number
from typing import Any
from numpy import conjugate, sqrt, abs as np_abs
from b_asic.port import InputPort, OutputPort
from b_asic.graph_id import GraphIDType
from b_asic.operation import AbstractOperation
from b_asic.graph_component import Name, TypeName
......@@ -335,3 +333,28 @@ class ConstantDivision(AbstractOperation):
@property
def type_name(self) -> TypeName:
return "cdiv"
class Butterfly(AbstractOperation):
"""Butterfly operation that returns two outputs.
The first output is a + b and the second output is a - b.
TODO: More info.
"""
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self), InputPort(1, self)]
self._output_ports = [OutputPort(0, self), OutputPort(1, self)]
if source1 is not None:
self._input_ports[0].connect(source1)
if source2 is not None:
self._input_ports[1].connect(source2)
def evaluate(self, a, b):
return a + b, a - b
@property
def type_name(self) -> TypeName:
return "bfly"
......@@ -5,12 +5,10 @@ TODO: More info.
from abc import abstractmethod
from numbers import Number
from typing import List, Dict, Optional, Any, Set, TYPE_CHECKING
from typing import List, Dict, Optional, Any, Set, Sequence, TYPE_CHECKING
from collections import deque
from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
from b_asic.simulation import SimulationState, OperationState
from b_asic.signal import Signal
if TYPE_CHECKING:
from b_asic.port import InputPort, OutputPort
......@@ -51,6 +49,12 @@ class Operation(GraphComponent):
"""Get the output port at index i."""
raise NotImplementedError
@abstractmethod
def evaluate_output(self, i: int, inputs: Sequence[Number]) -> Sequence[Optional[Number]]:
"""Evaluate the output port at the entered index with the entered input values and
returns all output values that are calulated during the evaluation in a list."""
raise NotImplementedError
@abstractmethod
def params(self) -> Dict[str, Optional[Any]]:
"""Get a dictionary of all parameter values."""
......@@ -70,13 +74,6 @@ class Operation(GraphComponent):
"""
raise NotImplementedError
@abstractmethod
def evaluate_outputs(self, state: "SimulationState") -> List[Number]:
"""Simulate the circuit until its iteration count matches that of the simulation state,
then return the resulting output vector.
"""
raise NotImplementedError
@abstractmethod
def split(self) -> "List[Operation]":
"""Split the operation into multiple operations.
......@@ -115,6 +112,15 @@ class AbstractOperation(Operation, AbstractGraphComponent):
"""
raise NotImplementedError
def evaluate_output(self, i: int, inputs: Sequence[Number]) -> Sequence[Optional[Number]]:
eval_return = self.evaluate(*inputs)
if isinstance(eval_return, Number):
return [eval_return]
elif isinstance(eval_return, (list, tuple)):
return eval_return
else:
raise TypeError("Incorrect returned type from evaluate function.")
def inputs(self) -> List["InputPort"]:
return self._input_ports.copy()
......@@ -143,33 +149,6 @@ class AbstractOperation(Operation, AbstractGraphComponent):
assert name in self._parameters # TODO: Error message.
self._parameters[name] = value
def evaluate_outputs(self, state: SimulationState) -> List[Number]:
# TODO: Check implementation.
input_count: int = self.input_count()
output_count: int = self.output_count()
assert input_count == len(self._input_ports) # TODO: Error message.
assert output_count == len(self._output_ports) # TODO: Error message.
self_state: OperationState = state.operation_states[self]
while self_state.iteration < state.iteration:
input_values: List[Number] = [0] * input_count
for i in range(input_count):
source: Signal = self._input_ports[i].signal
input_values[i] = source.operation.evaluate_outputs(state)[
source.port_index]
self_state.output_values = self.evaluate(input_values)
# TODO: Error message.
assert len(self_state.output_values) == output_count
self_state.iteration += 1
for i in range(output_count):
for signal in self._output_ports[i].signals():
destination: Signal = signal.destination
destination.evaluate_outputs(state)
return self_state.output_values
def split(self) -> List[Operation]:
# TODO: Check implementation.
results = self.evaluate(self._input_ports)
......@@ -265,4 +244,3 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return ConstantDivision(other, self.output(0))
else:
raise TypeError("Other type is not an Operation or a Number.")
......@@ -11,6 +11,7 @@ from b_asic.signal import Signal
PortIndex = NewType("PortIndex", int)
class Port(ABC):
"""Port Interface.
......@@ -126,6 +127,7 @@ class InputPort(AbstractPort):
@property
def value_length(self) -> Optional[int]:
"""Return the InputPorts value length."""
return self._value_length
@property
......@@ -144,7 +146,8 @@ class InputPort(AbstractPort):
def connect(self, port: "OutputPort") -> Signal:
assert self._source_signal is None, "Connecting new port to already connected input port."
return Signal(port, self) # self._source_signal is set by the signal constructor.
# self._source_signal is set by the signal constructor.
return Signal(port, self)
def add_signal(self, signal: Signal) -> None:
assert self._source_signal is None, "Connecting new port to already connected input port."
......@@ -183,24 +186,21 @@ class OutputPort(AbstractPort):
def signals(self) -> List[Signal]:
return self._destination_signals.copy()
def signal(self, i: int = 0) -> Signal:
assert 0 <= i < self.signal_count(), "Signal index out of bounds."
return self._destination_signals[i]
@property
def connected_ports(self) -> List[Port]:
return [signal.destination for signal in self._destination_signals \
if signal.destination is not None]
return [signal.destination for signal in self._destination_signals
if signal.destination is not None]
def signal_count(self) -> int:
return len(self._destination_signals)
def connect(self, port: InputPort) -> Signal:
return Signal(self, port) # Signal is added to self._destination_signals in signal constructor.
# Signal is added to self._destination_signals in signal constructor.
return Signal(self, port)
def add_signal(self, signal: Signal) -> None:
assert signal not in self.signals, \
"Attempting to connect to Signal already connected."
"Attempting to connect to Signal already connected."
self._destination_signals.append(signal)
if self is not signal.source:
# Connect this outputport to the signal if it isn't already.
......
......@@ -15,8 +15,8 @@ class Signal(AbstractGraphComponent):
_source: "OutputPort"
_destination: "InputPort"
def __init__(self, source: Optional["OutputPort"] = None, \
destination: Optional["InputPort"] = None, name: Name = ""):
def __init__(self, source: Optional["OutputPort"] = None,
destination: Optional["InputPort"] = None, name: Name = ""):
super().__init__(name)
......
......@@ -4,7 +4,7 @@ TODO: More info.
"""
from numbers import Number
from typing import List
from typing import List, Dict
class OperationState:
......@@ -25,11 +25,19 @@ class SimulationState:
TODO: More info.
"""
# operation_states: Dict[OperationId, OperationState]
operation_states: Dict[int, OperationState]
iteration: int
def __init__(self):
self.operation_states = {}
op_state = OperationState()
self.operation_states = {1: op_state}
self.iteration = 0
# TODO: More stuff.
# @property
# #def iteration(self):
# return self.iteration
# @iteration.setter
# def iteration(self, new_iteration: int):
# self.iteration = new_iteration
#
# TODO: More stuff
......@@ -2,226 +2,313 @@
B-ASIC test suite for the core operations.
"""
from b_asic.core_operations import Constant, Addition, Subtraction, Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, ConstantDivision
from b_asic.core_operations import Constant, Addition, Subtraction, \
Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, \
Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, \
ConstantDivision, Butterfly
# Constant tests.
def test_constant():
constant_operation = Constant(3)
assert constant_operation.evaluate() == 3
def test_constant_negative():
constant_operation = Constant(-3)
assert constant_operation.evaluate() == -3
def test_constant_complex():
constant_operation = Constant(3+4j)
assert constant_operation.evaluate() == 3+4j
# Addition tests.
def test_addition():
test_operation = Addition()
constant_operation = Constant(3)
constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 8
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 8
def test_addition_negative():
test_operation = Addition()
constant_operation = Constant(-3)
constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -8
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -8
def test_addition_complex():
test_operation = Addition()
constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j)
# Subtraction tests.
def test_subtraction():
test_operation = Subtraction()
constant_operation = Constant(5)
constant_operation_2 = Constant(3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 2
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 2
def test_subtraction_negative():
test_operation = Subtraction()
constant_operation = Constant(-5)
constant_operation_2 = Constant(-3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -2
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -2
def test_subtraction_complex():
test_operation = Subtraction()
constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j)
# Multiplication tests.
def test_multiplication():
test_operation = Multiplication()
constant_operation = Constant(5)
constant_operation_2 = Constant(3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 15
def test_multiplication_negative():
test_operation = Multiplication()
constant_operation = Constant(-5)
constant_operation_2 = Constant(-3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 15
def test_multiplication_complex():
test_operation = Multiplication()
constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j)
# Division tests.
def test_division():
test_operation = Division()
constant_operation = Constant(30)
constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 6
def test_division_negative():
test_operation = Division()
constant_operation = Constant(-30)
constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 6
def test_division_complex():
test_operation = Division()
constant_operation = Constant((60+40j))
constant_operation_2 = Constant((10+20j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j)
# SquareRoot tests.
def test_squareroot():
test_operation = SquareRoot()
constant_operation = Constant(36)
assert test_operation.evaluate(constant_operation.evaluate()) == 6
def test_squareroot_negative():
test_operation = SquareRoot()
constant_operation = Constant(-36)
assert test_operation.evaluate(constant_operation.evaluate()) == 6j
def test_squareroot_complex():
test_operation = SquareRoot()
constant_operation = Constant((48+64j))
assert test_operation.evaluate(constant_operation.evaluate()) == (8+4j)
# ComplexConjugate tests.
def test_complexconjugate():
test_operation = ComplexConjugate()
constant_operation = Constant(3+4j)
assert test_operation.evaluate(constant_operation.evaluate()) == (3-4j)
def test_test_complexconjugate_negative():
test_operation = ComplexConjugate()
constant_operation = Constant(-3-4j)
assert test_operation.evaluate(constant_operation.evaluate()) == (-3+4j)
# Max tests.
def test_max():
test_operation = Max()
constant_operation = Constant(30)
constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 30
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 30
def test_max_negative():
test_operation = Max()
constant_operation = Constant(-30)
constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -5
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -5
# Min tests.
def test_min():
test_operation = Min()
constant_operation = Constant(30)
constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 5
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 5
def test_min_negative():
test_operation = Min()
constant_operation = Constant(-30)
constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -30
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -30
# Absolute tests.
def test_absolute():
test_operation = Absolute()
constant_operation = Constant(30)
assert test_operation.evaluate(constant_operation.evaluate()) == 30
def test_absolute_negative():
test_operation = Absolute()
constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == 5
def test_absolute_complex():
test_operation = Absolute()
constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == 5.0
# ConstantMultiplication tests.
def test_constantmultiplication():
test_operation = ConstantMultiplication(5)
constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 100
def test_constantmultiplication_negative():
test_operation = ConstantMultiplication(5)
constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -25
def test_constantmultiplication_complex():
test_operation = ConstantMultiplication(3+2j)
constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (1+18j)
# ConstantAddition tests.
def test_constantaddition():
test_operation = ConstantAddition(5)
constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 25
def test_constantaddition_negative():
test_operation = ConstantAddition(4)
constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -1
def test_constantaddition_complex():
test_operation = ConstantAddition(3+2j)
constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (6+6j)
# ConstantSubtraction tests.
def test_constantsubtraction():
test_operation = ConstantSubtraction(5)
constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 15
def test_constantsubtraction_negative():
test_operation = ConstantSubtraction(4)
constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -9
def test_constantsubtraction_complex():
test_operation = ConstantSubtraction(4+6j)
constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (-1-2j)
# ConstantDivision tests.
def test_constantdivision():
test_operation = ConstantDivision(5)
constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 4
def test_constantdivision_negative():
test_operation = ConstantDivision(4)
constant_operation = Constant(-20)
assert test_operation.evaluate(constant_operation.evaluate()) == -5
def test_constantdivision_complex():
test_operation = ConstantDivision(2+2j)
constant_operation = Constant((10+10j))
assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j)
def test_butterfly():
test_operation = Butterfly()
assert list(test_operation.evaluate(2, 3)) == [5, -1]
def test_butterfly_negative():
test_operation = Butterfly()
assert list(test_operation.evaluate(-2, -3)) == [-5, 1]
def test_buttefly_complex():
test_operation = Butterfly()
assert list(test_operation.evaluate(2+1j, 3-2j)) == [5-1j, -1+3j]
from b_asic.core_operations import Constant, Addition
from b_asic.core_operations import Constant, Addition, ConstantAddition, Butterfly
from b_asic.signal import Signal
from b_asic.port import InputPort, OutputPort
import pytest
class TestTraverse:
def test_traverse_single_tree(self, operation):
"""Traverse a tree consisting of one operation."""
......@@ -20,8 +21,10 @@ class TestTraverse:
def test_traverse_type(self, large_operation_tree):
traverse = list(large_operation_tree.traverse())
assert len(list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3
assert len(list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4
assert len(
list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3
assert len(
list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4
def test_traverse_loop(self, operation_tree):
add_oper_signal = Signal()
......@@ -29,3 +32,43 @@ class TestTraverse:
operation_tree._input_ports[0].remove_signal(add_oper_signal)
operation_tree._input_ports[0].add_signal(add_oper_signal)
assert len(list(operation_tree.traverse())) == 2
class TestEvaluateOutput:
def test_evaluate_output_two_real_inputs(self):
"""Test evaluate_output for two real numbered inputs."""
add1 = Addition()
assert list(add1.evaluate_output(0, [1, 2])) == [3]
def test_evaluate_output_addition_two_complex_inputs(self):
"""Test evaluate_output for two complex numbered inputs."""
add1 = Addition()
assert list(add1.evaluate_output(0, [1+1j, 2])) == [3+1j]
def test_evaluate_output_one_real_input(self):
"""Test evaluate_output for one real numbered inputs."""
c_add1 = ConstantAddition(5)
assert list(c_add1.evaluate_output(0, [1])) == [6]
def test_evaluate_output_one_complex_input(self):
"""Test evaluate_output for one complex numbered inputs."""
c_add1 = ConstantAddition(5)
assert list(c_add1.evaluate_output(0, [1+1j])) == [6+1j]
def test_evaluate_output_two_real_inputs_two_outputs(self):
"""Test evaluate_output for two real inputs and two outputs."""
bfly1 = Butterfly()
assert list(bfly1.evaluate_output(0, [6, 9])) == [15, -3]
assert list(bfly1.evaluate_output(1, [6, 9])) == [15, -3]
def test_evaluate_output_two_complex_inputs_two_outputs(self):
"""Test evaluate_output for two complex inputs and two outputs."""
bfly1 = Butterfly()
assert list(bfly1.evaluate_output(0, [3+2j, 4+2j])) == [7+4j, -1]
assert list(bfly1.evaluate_output(1, [3+2j, 4+2j])) == [7+4j, -1]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment