diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 8902b169c07e600a843e526d44452fb9386ff7a9..e837d0070db95a0430671b8ae91235aa50ca829f 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -4,10 +4,8 @@ TODO: More info. """ from numbers import Number -from typing import Any from numpy import conjugate, sqrt, abs as np_abs from b_asic.port import InputPort, OutputPort -from b_asic.graph_id import GraphIDType from b_asic.operation import AbstractOperation from b_asic.graph_component import Name, TypeName @@ -335,3 +333,28 @@ class ConstantDivision(AbstractOperation): @property def type_name(self) -> TypeName: return "cdiv" + + +class Butterfly(AbstractOperation): + """Butterfly operation that returns two outputs. + The first output is a + b and the second output is a - b. + TODO: More info. + """ + + def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): + super().__init__(name) + self._input_ports = [InputPort(0, self), InputPort(1, self)] + self._output_ports = [OutputPort(0, self), OutputPort(1, self)] + + if source1 is not None: + self._input_ports[0].connect(source1) + + if source2 is not None: + self._input_ports[1].connect(source2) + + def evaluate(self, a, b): + return a + b, a - b + + @property + def type_name(self) -> TypeName: + return "bfly" diff --git a/b_asic/operation.py b/b_asic/operation.py index 5578e3c48edcf15594d6d1cd71e71a17521eca25..3dd761c2cd18bc97aac5e15bfd85575ed84da8f9 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -5,12 +5,10 @@ TODO: More info. from abc import abstractmethod from numbers import Number -from typing import List, Dict, Optional, Any, Set, TYPE_CHECKING +from typing import List, Dict, Optional, Any, Set, Sequence, TYPE_CHECKING from collections import deque from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name -from b_asic.simulation import SimulationState, OperationState -from b_asic.signal import Signal if TYPE_CHECKING: from b_asic.port import InputPort, OutputPort @@ -51,6 +49,12 @@ class Operation(GraphComponent): """Get the output port at index i.""" raise NotImplementedError + @abstractmethod + def evaluate_output(self, i: int, inputs: Sequence[Number]) -> Sequence[Optional[Number]]: + """Evaluate the output port at the entered index with the entered input values and + returns all output values that are calulated during the evaluation in a list.""" + raise NotImplementedError + @abstractmethod def params(self) -> Dict[str, Optional[Any]]: """Get a dictionary of all parameter values.""" @@ -70,13 +74,6 @@ class Operation(GraphComponent): """ raise NotImplementedError - @abstractmethod - def evaluate_outputs(self, state: "SimulationState") -> List[Number]: - """Simulate the circuit until its iteration count matches that of the simulation state, - then return the resulting output vector. - """ - raise NotImplementedError - @abstractmethod def split(self) -> "List[Operation]": """Split the operation into multiple operations. @@ -115,6 +112,15 @@ class AbstractOperation(Operation, AbstractGraphComponent): """ raise NotImplementedError + def evaluate_output(self, i: int, inputs: Sequence[Number]) -> Sequence[Optional[Number]]: + eval_return = self.evaluate(*inputs) + if isinstance(eval_return, Number): + return [eval_return] + elif isinstance(eval_return, (list, tuple)): + return eval_return + else: + raise TypeError("Incorrect returned type from evaluate function.") + def inputs(self) -> List["InputPort"]: return self._input_ports.copy() @@ -143,33 +149,6 @@ class AbstractOperation(Operation, AbstractGraphComponent): assert name in self._parameters # TODO: Error message. self._parameters[name] = value - def evaluate_outputs(self, state: SimulationState) -> List[Number]: - # TODO: Check implementation. - input_count: int = self.input_count() - output_count: int = self.output_count() - assert input_count == len(self._input_ports) # TODO: Error message. - assert output_count == len(self._output_ports) # TODO: Error message. - - self_state: OperationState = state.operation_states[self] - - while self_state.iteration < state.iteration: - input_values: List[Number] = [0] * input_count - for i in range(input_count): - source: Signal = self._input_ports[i].signal - input_values[i] = source.operation.evaluate_outputs(state)[ - source.port_index] - - self_state.output_values = self.evaluate(input_values) - # TODO: Error message. - assert len(self_state.output_values) == output_count - self_state.iteration += 1 - for i in range(output_count): - for signal in self._output_ports[i].signals(): - destination: Signal = signal.destination - destination.evaluate_outputs(state) - - return self_state.output_values - def split(self) -> List[Operation]: # TODO: Check implementation. results = self.evaluate(self._input_ports) @@ -265,4 +244,3 @@ class AbstractOperation(Operation, AbstractGraphComponent): return ConstantDivision(other, self.output(0)) else: raise TypeError("Other type is not an Operation or a Number.") - diff --git a/b_asic/port.py b/b_asic/port.py index 64c56bc9aef2ca71c78d08ccf0b8c20c1c68fd61..5900afb1adab0252386a70a46d9cbc7202a99b28 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -11,6 +11,7 @@ from b_asic.signal import Signal PortIndex = NewType("PortIndex", int) + class Port(ABC): """Port Interface. @@ -126,6 +127,7 @@ class InputPort(AbstractPort): @property def value_length(self) -> Optional[int]: + """Return the InputPorts value length.""" return self._value_length @property @@ -144,7 +146,8 @@ class InputPort(AbstractPort): def connect(self, port: "OutputPort") -> Signal: assert self._source_signal is None, "Connecting new port to already connected input port." - return Signal(port, self) # self._source_signal is set by the signal constructor. + # self._source_signal is set by the signal constructor. + return Signal(port, self) def add_signal(self, signal: Signal) -> None: assert self._source_signal is None, "Connecting new port to already connected input port." @@ -183,24 +186,21 @@ class OutputPort(AbstractPort): def signals(self) -> List[Signal]: return self._destination_signals.copy() - def signal(self, i: int = 0) -> Signal: - assert 0 <= i < self.signal_count(), "Signal index out of bounds." - return self._destination_signals[i] - @property def connected_ports(self) -> List[Port]: - return [signal.destination for signal in self._destination_signals \ - if signal.destination is not None] + return [signal.destination for signal in self._destination_signals + if signal.destination is not None] def signal_count(self) -> int: return len(self._destination_signals) def connect(self, port: InputPort) -> Signal: - return Signal(self, port) # Signal is added to self._destination_signals in signal constructor. + # Signal is added to self._destination_signals in signal constructor. + return Signal(self, port) def add_signal(self, signal: Signal) -> None: assert signal not in self.signals, \ - "Attempting to connect to Signal already connected." + "Attempting to connect to Signal already connected." self._destination_signals.append(signal) if self is not signal.source: # Connect this outputport to the signal if it isn't already. diff --git a/b_asic/signal.py b/b_asic/signal.py index 64c259486abd78e3b18ea824b58bfabe271f50d8..460bf5db9c7335adf52914610ab95381d6aed17c 100644 --- a/b_asic/signal.py +++ b/b_asic/signal.py @@ -15,8 +15,8 @@ class Signal(AbstractGraphComponent): _source: "OutputPort" _destination: "InputPort" - def __init__(self, source: Optional["OutputPort"] = None, \ - destination: Optional["InputPort"] = None, name: Name = ""): + def __init__(self, source: Optional["OutputPort"] = None, + destination: Optional["InputPort"] = None, name: Name = ""): super().__init__(name) diff --git a/b_asic/simulation.py b/b_asic/simulation.py index 50adaa522b6d685b428354a9f84689330b7fd40f..a2ce11b3263d517cba79c92093e594d712c5b8f3 100644 --- a/b_asic/simulation.py +++ b/b_asic/simulation.py @@ -4,7 +4,7 @@ TODO: More info. """ from numbers import Number -from typing import List +from typing import List, Dict class OperationState: @@ -25,11 +25,19 @@ class SimulationState: TODO: More info. """ - # operation_states: Dict[OperationId, OperationState] + operation_states: Dict[int, OperationState] iteration: int def __init__(self): - self.operation_states = {} + op_state = OperationState() + self.operation_states = {1: op_state} self.iteration = 0 - # TODO: More stuff. + # @property + # #def iteration(self): + # return self.iteration + # @iteration.setter + # def iteration(self, new_iteration: int): + # self.iteration = new_iteration + # + # TODO: More stuff diff --git a/test/test_core_operations.py b/test/test_core_operations.py index b176b2a6506cc5a1297813f6ddcb6d3589492838..854ccf85f447e430af303dc9a45c8946ac8d7828 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -2,226 +2,313 @@ B-ASIC test suite for the core operations. """ -from b_asic.core_operations import Constant, Addition, Subtraction, Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, ConstantDivision +from b_asic.core_operations import Constant, Addition, Subtraction, \ + Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, \ + Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, \ + ConstantDivision, Butterfly # Constant tests. + + def test_constant(): constant_operation = Constant(3) assert constant_operation.evaluate() == 3 + def test_constant_negative(): constant_operation = Constant(-3) assert constant_operation.evaluate() == -3 + def test_constant_complex(): constant_operation = Constant(3+4j) assert constant_operation.evaluate() == 3+4j # Addition tests. + + def test_addition(): test_operation = Addition() constant_operation = Constant(3) constant_operation_2 = Constant(5) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 8 + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 8 + def test_addition_negative(): test_operation = Addition() constant_operation = Constant(-3) constant_operation_2 = Constant(-5) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -8 + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == -8 + def test_addition_complex(): test_operation = Addition() constant_operation = Constant((3+5j)) constant_operation_2 = Constant((4+6j)) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j) # Subtraction tests. + + def test_subtraction(): test_operation = Subtraction() constant_operation = Constant(5) constant_operation_2 = Constant(3) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 2 + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 2 + def test_subtraction_negative(): test_operation = Subtraction() constant_operation = Constant(-5) constant_operation_2 = Constant(-3) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -2 + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == -2 + def test_subtraction_complex(): test_operation = Subtraction() constant_operation = Constant((3+5j)) constant_operation_2 = Constant((4+6j)) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j) # Multiplication tests. + + def test_multiplication(): test_operation = Multiplication() constant_operation = Constant(5) constant_operation_2 = Constant(3) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 + def test_multiplication_negative(): test_operation = Multiplication() constant_operation = Constant(-5) constant_operation_2 = Constant(-3) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 + def test_multiplication_complex(): test_operation = Multiplication() constant_operation = Constant((3+5j)) constant_operation_2 = Constant((4+6j)) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j) # Division tests. + + def test_division(): test_operation = Division() constant_operation = Constant(30) constant_operation_2 = Constant(5) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 + def test_division_negative(): test_operation = Division() constant_operation = Constant(-30) constant_operation_2 = Constant(-5) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 + def test_division_complex(): test_operation = Division() constant_operation = Constant((60+40j)) constant_operation_2 = Constant((10+20j)) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j) # SquareRoot tests. + + def test_squareroot(): test_operation = SquareRoot() constant_operation = Constant(36) assert test_operation.evaluate(constant_operation.evaluate()) == 6 + def test_squareroot_negative(): test_operation = SquareRoot() constant_operation = Constant(-36) assert test_operation.evaluate(constant_operation.evaluate()) == 6j + def test_squareroot_complex(): test_operation = SquareRoot() constant_operation = Constant((48+64j)) assert test_operation.evaluate(constant_operation.evaluate()) == (8+4j) # ComplexConjugate tests. + + def test_complexconjugate(): test_operation = ComplexConjugate() constant_operation = Constant(3+4j) assert test_operation.evaluate(constant_operation.evaluate()) == (3-4j) + def test_test_complexconjugate_negative(): test_operation = ComplexConjugate() constant_operation = Constant(-3-4j) assert test_operation.evaluate(constant_operation.evaluate()) == (-3+4j) # Max tests. + + def test_max(): test_operation = Max() constant_operation = Constant(30) constant_operation_2 = Constant(5) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 30 + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 30 + def test_max_negative(): test_operation = Max() constant_operation = Constant(-30) constant_operation_2 = Constant(-5) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -5 + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == -5 # Min tests. + + def test_min(): test_operation = Min() constant_operation = Constant(30) constant_operation_2 = Constant(5) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 5 + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 5 + def test_min_negative(): test_operation = Min() constant_operation = Constant(-30) constant_operation_2 = Constant(-5) - assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -30 + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == -30 # Absolute tests. + + def test_absolute(): test_operation = Absolute() constant_operation = Constant(30) assert test_operation.evaluate(constant_operation.evaluate()) == 30 + def test_absolute_negative(): test_operation = Absolute() constant_operation = Constant(-5) assert test_operation.evaluate(constant_operation.evaluate()) == 5 + def test_absolute_complex(): test_operation = Absolute() constant_operation = Constant((3+4j)) assert test_operation.evaluate(constant_operation.evaluate()) == 5.0 # ConstantMultiplication tests. + + def test_constantmultiplication(): test_operation = ConstantMultiplication(5) constant_operation = Constant(20) assert test_operation.evaluate(constant_operation.evaluate()) == 100 + def test_constantmultiplication_negative(): test_operation = ConstantMultiplication(5) constant_operation = Constant(-5) assert test_operation.evaluate(constant_operation.evaluate()) == -25 + def test_constantmultiplication_complex(): test_operation = ConstantMultiplication(3+2j) constant_operation = Constant((3+4j)) assert test_operation.evaluate(constant_operation.evaluate()) == (1+18j) # ConstantAddition tests. + + def test_constantaddition(): test_operation = ConstantAddition(5) constant_operation = Constant(20) assert test_operation.evaluate(constant_operation.evaluate()) == 25 + def test_constantaddition_negative(): test_operation = ConstantAddition(4) constant_operation = Constant(-5) assert test_operation.evaluate(constant_operation.evaluate()) == -1 + def test_constantaddition_complex(): test_operation = ConstantAddition(3+2j) constant_operation = Constant((3+4j)) assert test_operation.evaluate(constant_operation.evaluate()) == (6+6j) # ConstantSubtraction tests. + + def test_constantsubtraction(): test_operation = ConstantSubtraction(5) constant_operation = Constant(20) assert test_operation.evaluate(constant_operation.evaluate()) == 15 + def test_constantsubtraction_negative(): test_operation = ConstantSubtraction(4) constant_operation = Constant(-5) assert test_operation.evaluate(constant_operation.evaluate()) == -9 + def test_constantsubtraction_complex(): test_operation = ConstantSubtraction(4+6j) constant_operation = Constant((3+4j)) assert test_operation.evaluate(constant_operation.evaluate()) == (-1-2j) # ConstantDivision tests. + + def test_constantdivision(): test_operation = ConstantDivision(5) constant_operation = Constant(20) assert test_operation.evaluate(constant_operation.evaluate()) == 4 + def test_constantdivision_negative(): test_operation = ConstantDivision(4) constant_operation = Constant(-20) assert test_operation.evaluate(constant_operation.evaluate()) == -5 + def test_constantdivision_complex(): test_operation = ConstantDivision(2+2j) constant_operation = Constant((10+10j)) assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j) + + +def test_butterfly(): + test_operation = Butterfly() + assert list(test_operation.evaluate(2, 3)) == [5, -1] + + +def test_butterfly_negative(): + test_operation = Butterfly() + assert list(test_operation.evaluate(-2, -3)) == [-5, 1] + + +def test_buttefly_complex(): + test_operation = Butterfly() + assert list(test_operation.evaluate(2+1j, 3-2j)) == [5-1j, -1+3j] diff --git a/test/test_operation.py b/test/test_operation.py index 6c37e30bddd0b55ea69ae5b95a341c1ddeb56847..5891f3f8038bcf1aa451ff43092989ddb7bc8196 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -1,9 +1,10 @@ -from b_asic.core_operations import Constant, Addition +from b_asic.core_operations import Constant, Addition, ConstantAddition, Butterfly from b_asic.signal import Signal from b_asic.port import InputPort, OutputPort import pytest + class TestTraverse: def test_traverse_single_tree(self, operation): """Traverse a tree consisting of one operation.""" @@ -20,8 +21,10 @@ class TestTraverse: def test_traverse_type(self, large_operation_tree): traverse = list(large_operation_tree.traverse()) - assert len(list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3 - assert len(list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4 + assert len( + list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3 + assert len( + list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4 def test_traverse_loop(self, operation_tree): add_oper_signal = Signal() @@ -29,3 +32,43 @@ class TestTraverse: operation_tree._input_ports[0].remove_signal(add_oper_signal) operation_tree._input_ports[0].add_signal(add_oper_signal) assert len(list(operation_tree.traverse())) == 2 + + +class TestEvaluateOutput: + def test_evaluate_output_two_real_inputs(self): + """Test evaluate_output for two real numbered inputs.""" + add1 = Addition() + + assert list(add1.evaluate_output(0, [1, 2])) == [3] + + def test_evaluate_output_addition_two_complex_inputs(self): + """Test evaluate_output for two complex numbered inputs.""" + add1 = Addition() + + assert list(add1.evaluate_output(0, [1+1j, 2])) == [3+1j] + + def test_evaluate_output_one_real_input(self): + """Test evaluate_output for one real numbered inputs.""" + c_add1 = ConstantAddition(5) + + assert list(c_add1.evaluate_output(0, [1])) == [6] + + def test_evaluate_output_one_complex_input(self): + """Test evaluate_output for one complex numbered inputs.""" + c_add1 = ConstantAddition(5) + + assert list(c_add1.evaluate_output(0, [1+1j])) == [6+1j] + + def test_evaluate_output_two_real_inputs_two_outputs(self): + """Test evaluate_output for two real inputs and two outputs.""" + bfly1 = Butterfly() + + assert list(bfly1.evaluate_output(0, [6, 9])) == [15, -3] + assert list(bfly1.evaluate_output(1, [6, 9])) == [15, -3] + + def test_evaluate_output_two_complex_inputs_two_outputs(self): + """Test evaluate_output for two complex inputs and two outputs.""" + bfly1 = Butterfly() + + assert list(bfly1.evaluate_output(0, [3+2j, 4+2j])) == [7+4j, -1] + assert list(bfly1.evaluate_output(1, [3+2j, 4+2j])) == [7+4j, -1]