Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • da/B-ASIC
  • lukja239/B-ASIC
  • robal695/B-ASIC
3 results
Show changes
Commits on Source (2)
Showing
with 1080 additions and 927 deletions
...@@ -30,12 +30,12 @@ class Constant(AbstractOperation): ...@@ -30,12 +30,12 @@ class Constant(AbstractOperation):
@property @property
def value(self) -> Number: def value(self) -> Number:
"""TODO: docstring""" """Get the constant value of this operation."""
return self.param("value") return self.param("value")
@value.setter @value.setter
def value(self, value: Number): def value(self, value: Number) -> None:
"""TODO: docstring""" """Set the constant value of this operation."""
return self.set_param("value", value) return self.set_param("value", value)
...@@ -103,36 +103,22 @@ class Division(AbstractOperation): ...@@ -103,36 +103,22 @@ class Division(AbstractOperation):
return a / b return a / b
class SquareRoot(AbstractOperation): class Min(AbstractOperation):
"""Unary square root operation. """Binary min operation.
TODO: More info.
"""
def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
@property
def type_name(self) -> TypeName:
return "sqrt"
def evaluate(self, a):
return sqrt(complex(a))
class ComplexConjugate(AbstractOperation):
"""Unary complex conjugate operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1])
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "conj" return "min"
def evaluate(self, a): def evaluate(self, a, b):
return conjugate(a) assert not isinstance(a, complex) and not isinstance(b, complex), \
("core_operations.Min does not support complex numbers.")
return a if a < b else b
class Max(AbstractOperation): class Max(AbstractOperation):
...@@ -153,26 +139,8 @@ class Max(AbstractOperation): ...@@ -153,26 +139,8 @@ class Max(AbstractOperation):
return a if a > b else b return a if a > b else b
class Min(AbstractOperation): class SquareRoot(AbstractOperation):
"""Binary min operation. """Unary square root operation.
TODO: More info.
"""
def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1])
@property
def type_name(self) -> TypeName:
return "min"
def evaluate(self, a, b):
assert not isinstance(a, complex) and not isinstance(b, complex), \
("core_operations.Min does not support complex numbers.")
return a if a < b else b
class Absolute(AbstractOperation):
"""Unary absolute value operation.
TODO: More info. TODO: More info.
""" """
...@@ -181,48 +149,46 @@ class Absolute(AbstractOperation): ...@@ -181,48 +149,46 @@ class Absolute(AbstractOperation):
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "abs" return "sqrt"
def evaluate(self, a): def evaluate(self, a):
return np_abs(a) return sqrt(complex(a))
class ConstantMultiplication(AbstractOperation): class ComplexConjugate(AbstractOperation):
"""Unary constant multiplication operation. """Unary complex conjugate operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self.set_param("value", value)
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "cmul" return "conj"
def evaluate(self, a): def evaluate(self, a):
return a * self.param("value") return conjugate(a)
class ConstantAddition(AbstractOperation): class Absolute(AbstractOperation):
"""Unary constant addition operation. """Unary absolute value operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self.set_param("value", value)
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "cadd" return "abs"
def evaluate(self, a): def evaluate(self, a):
return a + self.param("value") return np_abs(a)
class ConstantSubtraction(AbstractOperation): class ConstantMultiplication(AbstractOperation):
"""Unary constant subtraction operation. """Unary constant multiplication operation.
TODO: More info. TODO: More info.
""" """
...@@ -232,27 +198,21 @@ class ConstantSubtraction(AbstractOperation): ...@@ -232,27 +198,21 @@ class ConstantSubtraction(AbstractOperation):
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "csub" return "cmul"
def evaluate(self, a): def evaluate(self, a):
return a - self.param("value") return a * self.param("value")
class ConstantDivision(AbstractOperation):
"""Unary constant division operation.
TODO: More info.
"""
def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self.set_param("value", value)
@property @property
def type_name(self) -> TypeName: def value(self) -> Number:
return "cdiv" """Get the constant value of this operation."""
return self.param("value")
@value.setter
def value(self, value: Number) -> None:
"""Set the constant value of this operation."""
return self.set_param("value", value)
def evaluate(self, a):
return a / self.param("value")
class Butterfly(AbstractOperation): class Butterfly(AbstractOperation):
"""Butterfly operation that returns two outputs. """Butterfly operation that returns two outputs.
...@@ -263,9 +223,9 @@ class Butterfly(AbstractOperation): ...@@ -263,9 +223,9 @@ class Butterfly(AbstractOperation):
def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 2, output_count = 2, name = name, input_sources = [src0, src1]) super().__init__(input_count = 2, output_count = 2, name = name, input_sources = [src0, src1])
def evaluate(self, a, b):
return a + b, a - b
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "bfly" return "bfly"
def evaluate(self, a, b):
return a + b, a - b
...@@ -4,11 +4,15 @@ TODO: More info. ...@@ -4,11 +4,15 @@ TODO: More info.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from copy import copy from collections import deque
from typing import NewType from copy import copy, deepcopy
from typing import NewType, Any, Dict, Mapping, Iterable, Generator
Name = NewType("Name", str) Name = NewType("Name", str)
TypeName = NewType("TypeName", str) TypeName = NewType("TypeName", str)
GraphID = NewType("GraphID", str)
GraphIDNumber = NewType("GraphIDNumber", int)
class GraphComponent(ABC): class GraphComponent(ABC):
...@@ -19,37 +23,87 @@ class GraphComponent(ABC): ...@@ -19,37 +23,87 @@ class GraphComponent(ABC):
@property @property
@abstractmethod @abstractmethod
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
"""Return the type name of the graph component""" """Get the type name of this graph component"""
raise NotImplementedError raise NotImplementedError
@property @property
@abstractmethod @abstractmethod
def name(self) -> Name: def name(self) -> Name:
"""Return the name of the graph component.""" """Get the name of this graph component."""
raise NotImplementedError raise NotImplementedError
@name.setter @name.setter
@abstractmethod @abstractmethod
def name(self, name: Name) -> None: def name(self, name: Name) -> None:
"""Set the name of the graph component to the entered name.""" """Set the name of this graph component to the given name."""
raise NotImplementedError
@property
@abstractmethod
def graph_id(self) -> GraphID:
"""Get the graph id of this graph component."""
raise NotImplementedError raise NotImplementedError
@graph_id.setter
@abstractmethod @abstractmethod
def copy_unconnected(self) -> "GraphComponent": def graph_id(self, graph_id: GraphID) -> None:
"""Get a copy of this graph component, except without any connected components.""" """Set the graph id of this graph component to the given id.
Note that this id will be ignored if this component is used to create a new graph,
and that a new local id will be generated for it instead."""
raise NotImplementedError
@property
@abstractmethod
def params(self) -> Mapping[str, Any]:
"""Get a dictionary of all parameter values."""
raise NotImplementedError
@abstractmethod
def param(self, name: str) -> Any:
"""Get the value of a parameter.
Returns None if the parameter is not defined.
"""
raise NotImplementedError
@abstractmethod
def set_param(self, name: str, value: Any) -> None:
"""Set the value of a parameter.
Adds the parameter if it is not already defined.
"""
raise NotImplementedError
@abstractmethod
def copy_component(self, *args, **kwargs) -> "GraphComponent":
"""Get a new instance of this graph component type with the same name, id and parameters."""
raise NotImplementedError
@property
@abstractmethod
def neighbors(self) -> Iterable["GraphComponent"]:
"""Get all components that are directly connected to this operation."""
raise NotImplementedError
@abstractmethod
def traverse(self) -> Generator["GraphComponent", None, None]:
"""Get a generator that recursively iterates through all components that are connected to this operation,
as well as the ones that they are connected to.
"""
raise NotImplementedError raise NotImplementedError
class AbstractGraphComponent(GraphComponent): class AbstractGraphComponent(GraphComponent):
"""Abstract Graph Component class which is a component of a signal flow graph. """Abstract Graph Component class which is a component of a signal flow graph.
TODO: More info. TODO: More info.
""" """
_name: Name _name: Name
_graph_id: GraphID
_parameters: Dict[str, Any]
def __init__(self, name: Name = ""): def __init__(self, name: Name = ""):
self._name = name self._name = name
self._graph_id = ""
self._parameters = {}
@property @property
def name(self) -> Name: def name(self) -> Name:
...@@ -58,8 +112,41 @@ class AbstractGraphComponent(GraphComponent): ...@@ -58,8 +112,41 @@ class AbstractGraphComponent(GraphComponent):
@name.setter @name.setter
def name(self, name: Name) -> None: def name(self, name: Name) -> None:
self._name = name self._name = name
@property
def graph_id(self) -> GraphID:
return self._graph_id
@graph_id.setter
def graph_id(self, graph_id: GraphID) -> None:
self._graph_id = graph_id
def copy_unconnected(self) -> GraphComponent: @property
new_comp = self.__class__() def params(self) -> Mapping[str, Any]:
new_comp.name = copy(self.name) return self._parameters.copy()
return new_comp
\ No newline at end of file def param(self, name: str) -> Any:
return self._parameters.get(name)
def set_param(self, name: str, value: Any) -> None:
self._parameters[name] = value
def copy_component(self, *args, **kwargs) -> GraphComponent:
new_component = self.__class__(*args, **kwargs)
new_component.name = copy(self.name)
new_component.graph_id = copy(self.graph_id)
for name, value in self.params.items():
new_component.set_param(copy(name), deepcopy(value)) # pylint: disable=no-member
return new_component
def traverse(self) -> Generator[GraphComponent, None, None]:
# Breadth first search.
visited = {self}
fontier = deque([self])
while fontier:
component = fontier.popleft()
yield component
for neighbor in component.neighbors:
if neighbor not in visited:
visited.add(neighbor)
fontier.append(neighbor)
\ No newline at end of file
This diff is collapsed.
...@@ -108,12 +108,10 @@ class InputPort(AbstractPort): ...@@ -108,12 +108,10 @@ class InputPort(AbstractPort):
""" """
_source_signal: Optional[Signal] _source_signal: Optional[Signal]
_value_length: Optional[int]
def __init__(self, operation: "Operation", index: int): def __init__(self, operation: "Operation", index: int):
super().__init__(operation, index) super().__init__(operation, index)
self._source_signal = None self._source_signal = None
self._value_length = None
@property @property
def signal_count(self) -> int: def signal_count(self) -> int:
...@@ -153,18 +151,6 @@ class InputPort(AbstractPort): ...@@ -153,18 +151,6 @@ class InputPort(AbstractPort):
# self._source_signal is set by the signal constructor. # self._source_signal is set by the signal constructor.
return Signal(source=src.source, destination=self, name=name) return Signal(source=src.source, destination=self, name=name)
@property
def value_length(self) -> Optional[int]:
"""Get the number of bits that this port should truncate received values to."""
return self._value_length
@value_length.setter
def value_length(self, bits: Optional[int]) -> None:
"""Set the number of bits that this port should truncate received values to."""
assert bits is None or (isinstance(
bits, int) and bits >= 0), "Value length must be non-negative."
self._value_length = bits
class OutputPort(AbstractPort, SignalSourceProvider): class OutputPort(AbstractPort, SignalSourceProvider):
"""Output port. """Output port.
......
"""@package docstring """@package docstring
B-ASIC Signal Module. B-ASIC Signal Module.
""" """
from typing import Optional, TYPE_CHECKING from typing import Optional, Iterable, TYPE_CHECKING
from b_asic.graph_component import AbstractGraphComponent, TypeName, Name from b_asic.graph_component import GraphComponent, AbstractGraphComponent, TypeName, Name
if TYPE_CHECKING: if TYPE_CHECKING:
from b_asic.port import InputPort, OutputPort from b_asic.port import InputPort, OutputPort
...@@ -15,8 +15,7 @@ class Signal(AbstractGraphComponent): ...@@ -15,8 +15,7 @@ class Signal(AbstractGraphComponent):
_source: Optional["OutputPort"] _source: Optional["OutputPort"]
_destination: Optional["InputPort"] _destination: Optional["InputPort"]
def __init__(self, source: Optional["OutputPort"] = None, \ def __init__(self, source: Optional["OutputPort"] = None, destination: Optional["InputPort"] = None, bits: Optional[int] = None, name: Name = ""):
destination: Optional["InputPort"] = None, name: Name = ""):
super().__init__(name) super().__init__(name)
self._source = None self._source = None
self._destination = None self._destination = None
...@@ -24,7 +23,16 @@ class Signal(AbstractGraphComponent): ...@@ -24,7 +23,16 @@ class Signal(AbstractGraphComponent):
self.set_source(source) self.set_source(source)
if destination is not None: if destination is not None:
self.set_destination(destination) self.set_destination(destination)
self.set_param("bits", bits)
@property
def type_name(self) -> TypeName:
return "s"
@property
def neighbors(self) -> Iterable[GraphComponent]:
return [p.operation for p in [self.source, self.destination] if p is not None]
@property @property
def source(self) -> Optional["OutputPort"]: def source(self) -> Optional["OutputPort"]:
"""Return the source OutputPort of the signal.""" """Return the source OutputPort of the signal."""
...@@ -63,10 +71,6 @@ class Signal(AbstractGraphComponent): ...@@ -63,10 +71,6 @@ class Signal(AbstractGraphComponent):
if self not in dest.signals: if self not in dest.signals:
dest.add_signal(self) dest.add_signal(self)
@property
def type_name(self) -> TypeName:
return "s"
def remove_source(self) -> None: def remove_source(self) -> None:
"""Disconnect the source OutputPort of the signal. If the source port """Disconnect the source OutputPort of the signal. If the source port
still is connected to this signal then also disconnect the source port.""" still is connected to this signal then also disconnect the source port."""
...@@ -88,3 +92,16 @@ class Signal(AbstractGraphComponent): ...@@ -88,3 +92,16 @@ class Signal(AbstractGraphComponent):
"""Returns true if the signal is missing either a source or a destination, """Returns true if the signal is missing either a source or a destination,
else false.""" else false."""
return self._source is None or self._destination is None return self._source is None or self._destination is None
@property
def bits(self) -> Optional[int]:
"""Get the number of bits that this operations using this signal as an input should truncate received values to.
None = unlimited."""
return self.param("bits")
@bits.setter
def bits(self, bits: Optional[int]) -> None:
"""Set the number of bits that operations using this signal as an input should truncate received values to.
None = unlimited."""
assert bits is None or (isinstance(bits, int) and bits >= 0), "Bits must be non-negative."
self.set_param("bits", bits)
\ No newline at end of file
This diff is collapsed.
...@@ -3,41 +3,111 @@ B-ASIC Simulation Module. ...@@ -3,41 +3,111 @@ B-ASIC Simulation Module.
TODO: More info. TODO: More info.
""" """
from collections import defaultdict
from numbers import Number from numbers import Number
from typing import List, Dict from typing import List, Dict, DefaultDict, Callable, Sequence, Mapping, Union, Optional
from b_asic.operation import ResultKey, ResultMap
from b_asic.signal_flow_graph import SFG
class OperationState:
"""Simulation state of an operation. InputProvider = Union[Number, Sequence[Number], Callable[[int], Number]]
class Simulation:
"""Simulation.
TODO: More info. TODO: More info.
""" """
output_values: List[Number] _sfg: SFG
iteration: int _results: DefaultDict[int, Dict[str, Number]]
_registers: Dict[str, Number]
_iteration: int
_input_functions: Sequence[Callable[[int], Number]]
_current_input_values: Sequence[Number]
_latest_output_values: Sequence[Number]
_save_results: bool
def __init__(self): def __init__(self, sfg: SFG, input_providers: Optional[Sequence[Optional[InputProvider]]] = None, save_results: bool = False):
self.output_values = [] self._sfg = sfg
self.iteration = 0 self._results = defaultdict(dict)
self._registers = {}
self._iteration = 0
self._input_functions = [lambda _: 0 for _ in range(self._sfg.input_count)]
self._current_input_values = [0 for _ in range(self._sfg.input_count)]
self._latest_output_values = [0 for _ in range(self._sfg.output_count)]
self._save_results = save_results
if input_providers is not None:
self.set_inputs(input_providers)
def set_input(self, index: int, input_provider: InputProvider) -> None:
"""Set the input function used to get values for the specific input at the given index to the internal SFG."""
if index < 0 or index >= len(self._input_functions):
raise IndexError(f"Input index out of range (expected 0-{len(self._input_functions) - 1}, got {index})")
if callable(input_provider):
self._input_functions[index] = input_provider
elif isinstance(input_provider, Number):
self._input_functions[index] = lambda _: input_provider
else:
self._input_functions[index] = lambda n: input_provider[n]
class SimulationState: def set_inputs(self, input_providers: Sequence[Optional[InputProvider]]) -> None:
"""Simulation state. """Set the input functions used to get values for the inputs to the internal SFG."""
TODO: More info. if len(input_providers) != self._sfg.input_count:
""" raise ValueError(f"Wrong number of inputs supplied to simulation (expected {self._sfg.input_count}, got {len(input_providers)})")
self._input_functions = [None for _ in range(self._sfg.input_count)]
for index, input_provider in enumerate(input_providers):
if input_provider is not None:
self.set_input(index, input_provider)
@property
def save_results(self) -> bool:
"""Get the flag that determines if the results of ."""
return self._save_results
@save_results.setter
def save_results(self, save_results) -> None:
self._save_results = save_results
def run(self) -> Sequence[Number]:
"""Run one iteration of the simulation and return the resulting output values."""
return self.run_for(1)
def run_until(self, iteration: int) -> Sequence[Number]:
"""Run the simulation until its iteration is greater than or equal to the given iteration
and return the resulting output values.
"""
while self._iteration < iteration:
self._current_input_values = [self._input_functions[i](self._iteration) for i in range(self._sfg.input_count)]
self._latest_output_values = self._sfg.evaluate_outputs(self._current_input_values, self._results[self._iteration], self._registers)
if not self._save_results:
del self._results[self.iteration]
self._iteration += 1
return self._latest_output_values
def run_for(self, iterations: int) -> Sequence[Number]:
"""Run a given number of iterations of the simulation and return the resulting output values."""
return self.run_until(self._iteration + iterations)
@property
def iteration(self) -> int:
"""Get the current iteration number of the simulation."""
return self._iteration
@property
def results(self) -> Mapping[int, ResultMap]:
"""Get a mapping of all results, including intermediate values, calculated for each iteration up until now.
The outer mapping maps from iteration number to value mapping. The value mapping maps output port identifiers to values.
Example: {0: {"c1": 3, "c2": 4, "bfly1.0": 7, "bfly1.1": -1, "0": 7}}
"""
return self._results
def clear_results(self) -> None:
"""Clear all results that were saved until now."""
self._results.clear()
operation_states: Dict[int, OperationState] def clear_state(self) -> None:
iteration: int """Clear all current state of the simulation, except for the results and iteration."""
self._registers.clear()
def __init__(self): self._current_input_values = [0 for _ in range(self._sfg.input_count)]
op_state = OperationState() self._latest_output_values = [0 for _ in range(self._sfg.output_count)]
self.operation_states = {1: op_state} \ No newline at end of file
self.iteration = 0
# @property
# #def iteration(self):
# return self.iteration
# @iteration.setter
# def iteration(self, new_iteration: int):
# self.iteration = new_iteration
#
# TODO: More stuff
...@@ -4,9 +4,9 @@ TODO: More info. ...@@ -4,9 +4,9 @@ TODO: More info.
""" """
from numbers import Number from numbers import Number
from typing import Optional from typing import Optional, Sequence
from b_asic.operation import AbstractOperation from b_asic.operation import AbstractOperation, ResultKey, RegisterMap, MutableResultMap, MutableRegisterMap
from b_asic.graph_component import Name, TypeName from b_asic.graph_component import Name, TypeName
from b_asic.port import SignalSourceProvider from b_asic.port import SignalSourceProvider
...@@ -29,12 +29,12 @@ class Input(AbstractOperation): ...@@ -29,12 +29,12 @@ class Input(AbstractOperation):
@property @property
def value(self) -> Number: def value(self) -> Number:
"""TODO: docstring""" """Get the current value of this input."""
return self.param("value") return self.param("value")
@value.setter @value.setter
def value(self, value: Number): def value(self, value: Number) -> None:
"""TODO: docstring""" """Set the current value of this input."""
self.set_param("value", value) self.set_param("value", value)
...@@ -44,11 +44,48 @@ class Output(AbstractOperation): ...@@ -44,11 +44,48 @@ class Output(AbstractOperation):
""" """
def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 1, output_count = 0, name = name, input_sources=[src0]) super().__init__(input_count = 1, output_count = 0, name = name, input_sources = [src0])
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "out" return "out"
def evaluate(self): def evaluate(self, _):
return None return None
\ No newline at end of file
class Register(AbstractOperation):
"""Unit delay operation.
TODO: More info.
"""
def __init__(self, src0: Optional[SignalSourceProvider] = None, initial_value: Number = 0, name: Name = ""):
super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self.set_param("initial_value", initial_value)
@property
def type_name(self) -> TypeName:
return "reg"
def evaluate(self, a):
return self.param("initial_value")
def current_output(self, index: int, registers: Optional[RegisterMap] = None, prefix: str = "") -> Optional[Number]:
if registers is not None:
return registers.get(self.key(index, prefix), self.param("initial_value"))
return self.param("initial_value")
def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number:
if index != 0:
raise IndexError(f"Output index out of range (expected 0-0, got {index})")
if len(input_values) != 1:
raise ValueError(f"Wrong number of inputs supplied to SFG for evaluation (expected 1, got {len(input_values)})")
key = self.key(index, prefix)
value = self.param("initial_value")
if registers is not None:
value = registers.get(key, value)
registers[key] = self.truncate_inputs(input_values)[0]
if results is not None:
results[key] = value
return value
\ No newline at end of file
from test.fixtures.signal import signal, signals from test.fixtures.signal import signal, signals
from test.fixtures.operation_tree import * from test.fixtures.operation_tree import *
from test.fixtures.port import * from test.fixtures.port import *
from test.fixtures.signal_flow_graph import *
import pytest import pytest
from b_asic.core_operations import Addition, Constant
from b_asic.signal import Signal
import pytest import pytest
from b_asic import Addition, Constant, Signal
@pytest.fixture @pytest.fixture
def operation(): def operation():
return Constant(2) return Constant(2)
@pytest.fixture @pytest.fixture
def operation_tree(): def operation_tree():
"""Return a addition operation connected with 2 constants. """Valid addition operation connected with 2 constants.
---C---+ 2---+
+--A |
---C---+ v
add = 2 + 3 = 5
^
|
3---+
""" """
return Addition(Constant(2), Constant(3)) return Addition(Constant(2), Constant(3))
@pytest.fixture @pytest.fixture
def large_operation_tree(): def large_operation_tree():
"""Return an addition operation connected with a large operation tree with 2 other additions and 4 constants. """Valid addition operation connected with a large operation tree with 2 other additions and 4 constants.
---C---+ 2---+
+--A---+ |
---C---+ | v
+---A add---+
---C---+ | ^ |
+--A---+ | |
---C---+ 3---+ v
add = (2 + 3) + (4 + 5) = 14
4---+ ^
| |
v |
add---+
^
|
5---+
""" """
return Addition(Addition(Constant(2), Constant(3)), Addition(Constant(4), Constant(5))) return Addition(Addition(Constant(2), Constant(3)), Addition(Constant(4), Constant(5)))
@pytest.fixture
def operation_graph_with_cycle():
"""Invalid addition operation connected with an operation graph containing a cycle.
+-+
| |
v |
add+---+
^ |
| v
7 add = (? + 7) + 6 = ?
^
|
6
"""
add1 = Addition(None, Constant(7))
add1.input(0).connect(add1)
return Addition(add1, Constant(6))
import pytest import pytest
from b_asic.port import InputPort, OutputPort
from b_asic import InputPort, OutputPort
@pytest.fixture @pytest.fixture
def input_port(): def input_port():
......
import pytest import pytest
from b_asic import Signal from b_asic import Signal
@pytest.fixture @pytest.fixture
def signal(): def signal():
"""Return a signal with no connections.""" """Return a signal with no connections."""
......
import pytest
from b_asic import SFG, Input, Output, Constant, Register
@pytest.fixture
def sfg_two_inputs_two_outputs():
"""Valid SFG with two inputs and two outputs.
. .
in1-------+ +--------->out1
. | | .
. v | .
. add1+--+ .
. ^ | .
. | v .
in2+------+ add2---->out2
| . ^ .
| . | .
+------------+ .
. .
out1 = in1 + in2
out2 = in1 + 2 * in2
"""
in1 = Input()
in2 = Input()
add1 = in1 + in2
add2 = add1 + in2
out1 = Output(add1)
out2 = Output(add2)
return SFG(inputs = [in1, in2], outputs = [out1, out2])
@pytest.fixture
def sfg_nested():
"""Valid SFG with two inputs and one output.
out1 = in1 + (in1 + in1 * in2) * (in1 + in2 * (in1 + in1 * in2))
"""
mac_in1 = Input()
mac_in2 = Input()
mac_in3 = Input()
mac_out1 = Output(mac_in1 + mac_in2 * mac_in3)
MAC = SFG(inputs = [mac_in1, mac_in2, mac_in3], outputs = [mac_out1])
in1 = Input()
in2 = Input()
mac1 = MAC(in1, in1, in2)
mac2 = MAC(in1, in2, mac1)
mac3 = MAC(in1, mac1, mac2)
out1 = Output(mac3)
return SFG(inputs = [in1, in2], outputs = [out1])
@pytest.fixture
def sfg_delay():
"""Valid SFG with one input and one output.
out1 = in1'
"""
in1 = Input()
reg1 = Register(in1)
out1 = Output(reg1)
return SFG(inputs = [in1], outputs = [out1])
@pytest.fixture
def sfg_accumulator():
"""Valid SFG with two inputs and one output.
data_out = (data_in' + data_in) * (1 - reset)
"""
data_in = Input()
reset = Input()
reg = Register()
reg.input(0).connect((reg + data_in) * (1 - reset))
data_out = Output(reg)
return SFG(inputs = [data_in, reset], outputs = [data_out])
\ No newline at end of file
...@@ -2,11 +2,10 @@ ...@@ -2,11 +2,10 @@
B-ASIC test suite for the AbstractOperation class. B-ASIC test suite for the AbstractOperation class.
""" """
from b_asic.core_operations import Addition, ConstantAddition, Subtraction, ConstantSubtraction, \
Multiplication, ConstantMultiplication, Division, ConstantDivision
import pytest import pytest
from b_asic import Addition, Subtraction, Multiplication, ConstantMultiplication, Division
def test_addition_overload(): def test_addition_overload():
"""Tests addition overloading for both operation and number argument.""" """Tests addition overloading for both operation and number argument."""
...@@ -14,15 +13,19 @@ def test_addition_overload(): ...@@ -14,15 +13,19 @@ def test_addition_overload():
add2 = Addition(None, None, "add2") add2 = Addition(None, None, "add2")
add3 = add1 + add2 add3 = add1 + add2
assert isinstance(add3, Addition) assert isinstance(add3, Addition)
assert add3.input(0).signals == add1.output(0).signals assert add3.input(0).signals == add1.output(0).signals
assert add3.input(1).signals == add2.output(0).signals assert add3.input(1).signals == add2.output(0).signals
add4 = add3 + 5 add4 = add3 + 5
assert isinstance(add4, Addition)
assert isinstance(add4, ConstantAddition)
assert add4.input(0).signals == add3.output(0).signals assert add4.input(0).signals == add3.output(0).signals
assert add4.input(1).signals[0].source.operation.value == 5
add5 = 5 + add4
assert isinstance(add5, Addition)
assert add5.input(0).signals[0].source.operation.value == 5
assert add5.input(1).signals == add4.output(0).signals
def test_subtraction_overload(): def test_subtraction_overload():
...@@ -31,15 +34,19 @@ def test_subtraction_overload(): ...@@ -31,15 +34,19 @@ def test_subtraction_overload():
add2 = Addition(None, None, "add2") add2 = Addition(None, None, "add2")
sub1 = add1 - add2 sub1 = add1 - add2
assert isinstance(sub1, Subtraction) assert isinstance(sub1, Subtraction)
assert sub1.input(0).signals == add1.output(0).signals assert sub1.input(0).signals == add1.output(0).signals
assert sub1.input(1).signals == add2.output(0).signals assert sub1.input(1).signals == add2.output(0).signals
sub2 = sub1 - 5 sub2 = sub1 - 5
assert isinstance(sub2, Subtraction)
assert isinstance(sub2, ConstantSubtraction)
assert sub2.input(0).signals == sub1.output(0).signals assert sub2.input(0).signals == sub1.output(0).signals
assert sub2.input(1).signals[0].source.operation.value == 5
sub3 = 5 - sub2
assert isinstance(sub3, Subtraction)
assert sub3.input(0).signals[0].source.operation.value == 5
assert sub3.input(1).signals == sub2.output(0).signals
def test_multiplication_overload(): def test_multiplication_overload():
...@@ -48,15 +55,19 @@ def test_multiplication_overload(): ...@@ -48,15 +55,19 @@ def test_multiplication_overload():
add2 = Addition(None, None, "add2") add2 = Addition(None, None, "add2")
mul1 = add1 * add2 mul1 = add1 * add2
assert isinstance(mul1, Multiplication) assert isinstance(mul1, Multiplication)
assert mul1.input(0).signals == add1.output(0).signals assert mul1.input(0).signals == add1.output(0).signals
assert mul1.input(1).signals == add2.output(0).signals assert mul1.input(1).signals == add2.output(0).signals
mul2 = mul1 * 5 mul2 = mul1 * 5
assert isinstance(mul2, ConstantMultiplication) assert isinstance(mul2, ConstantMultiplication)
assert mul2.input(0).signals == mul1.output(0).signals assert mul2.input(0).signals == mul1.output(0).signals
assert mul2.value == 5
mul3 = 5 * mul2
assert isinstance(mul3, ConstantMultiplication)
assert mul3.input(0).signals == mul2.output(0).signals
assert mul3.value == 5
def test_division_overload(): def test_division_overload():
...@@ -65,13 +76,17 @@ def test_division_overload(): ...@@ -65,13 +76,17 @@ def test_division_overload():
add2 = Addition(None, None, "add2") add2 = Addition(None, None, "add2")
div1 = add1 / add2 div1 = add1 / add2
assert isinstance(div1, Division) assert isinstance(div1, Division)
assert div1.input(0).signals == add1.output(0).signals assert div1.input(0).signals == add1.output(0).signals
assert div1.input(1).signals == add2.output(0).signals assert div1.input(1).signals == add2.output(0).signals
div2 = div1 / 5 div2 = div1 / 5
assert isinstance(div2, Division)
assert isinstance(div2, ConstantDivision)
assert div2.input(0).signals == div1.output(0).signals assert div2.input(0).signals == div1.output(0).signals
assert div2.input(1).signals[0].source.operation.value == 5
div3 = 5 / div2
assert isinstance(div3, Division)
assert div3.input(0).signals[0].source.operation.value == 5
assert div3.input(1).signals == div2.output(0).signals
...@@ -2,313 +2,165 @@ ...@@ -2,313 +2,165 @@
B-ASIC test suite for the core operations. B-ASIC test suite for the core operations.
""" """
from b_asic.core_operations import Constant, Addition, Subtraction, \ from b_asic import \
Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, \ Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \
Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, \ SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly
ConstantDivision, Butterfly
# Constant tests.
class TestConstant:
def test_constant_positive(self):
test_operation = Constant(3)
assert test_operation.evaluate_output(0, []) == 3
def test_constant(): def test_constant_negative(self):
constant_operation = Constant(3) test_operation = Constant(-3)
assert constant_operation.evaluate() == 3 assert test_operation.evaluate_output(0, []) == -3
def test_constant_complex(self):
test_operation = Constant(3+4j)
assert test_operation.evaluate_output(0, []) == 3+4j
def test_constant_negative():
constant_operation = Constant(-3)
assert constant_operation.evaluate() == -3
class TestAddition:
def test_addition_positive(self):
test_operation = Addition()
assert test_operation.evaluate_output(0, [3, 5]) == 8
def test_constant_complex(): def test_addition_negative(self):
constant_operation = Constant(3+4j) test_operation = Addition()
assert constant_operation.evaluate() == 3+4j assert test_operation.evaluate_output(0, [-3, -5]) == -8
# Addition tests. def test_addition_complex(self):
test_operation = Addition()
assert test_operation.evaluate_output(0, [3+5j, 4+6j]) == 7+11j
def test_addition(): class TestSubtraction:
test_operation = Addition() def test_subtraction_positive(self):
constant_operation = Constant(3) test_operation = Subtraction()
constant_operation_2 = Constant(5) assert test_operation.evaluate_output(0, [5, 3]) == 2
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 8
def test_subtraction_negative(self):
test_operation = Subtraction()
assert test_operation.evaluate_output(0, [-5, -3]) == -2
def test_addition_negative(): def test_subtraction_complex(self):
test_operation = Addition() test_operation = Subtraction()
constant_operation = Constant(-3) assert test_operation.evaluate_output(0, [3+5j, 4+6j]) == -1-1j
constant_operation_2 = Constant(-5)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -8
def test_addition_complex(): class TestMultiplication:
test_operation = Addition() def test_multiplication_positive(self):
constant_operation = Constant((3+5j)) test_operation = Multiplication()
constant_operation_2 = Constant((4+6j)) assert test_operation.evaluate_output(0, [5, 3]) == 15
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j)
# Subtraction tests. def test_multiplication_negative(self):
test_operation = Multiplication()
assert test_operation.evaluate_output(0, [-5, -3]) == 15
def test_multiplication_complex(self):
test_operation = Multiplication()
assert test_operation.evaluate_output(0, [3+5j, 4+6j]) == -18+38j
def test_subtraction():
test_operation = Subtraction()
constant_operation = Constant(5)
constant_operation_2 = Constant(3)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 2
class TestDivision:
def test_division_positive(self):
test_operation = Division()
assert test_operation.evaluate_output(0, [30, 5]) == 6
def test_subtraction_negative(): def test_division_negative(self):
test_operation = Subtraction() test_operation = Division()
constant_operation = Constant(-5) assert test_operation.evaluate_output(0, [-30, -5]) == 6
constant_operation_2 = Constant(-3)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -2
def test_division_complex(self):
test_operation = Division()
assert test_operation.evaluate_output(0, [60+40j, 10+20j]) == 2.8-1.6j
def test_subtraction_complex():
test_operation = Subtraction()
constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j)
# Multiplication tests. class TestSquareRoot:
def test_squareroot_positive(self):
test_operation = SquareRoot()
assert test_operation.evaluate_output(0, [36]) == 6
def test_squareroot_negative(self):
test_operation = SquareRoot()
assert test_operation.evaluate_output(0, [-36]) == 6j
def test_multiplication(): def test_squareroot_complex(self):
test_operation = Multiplication() test_operation = SquareRoot()
constant_operation = Constant(5) assert test_operation.evaluate_output(0, [48+64j]) == 8+4j
constant_operation_2 = Constant(3)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 15
def test_multiplication_negative(): class TestComplexConjugate:
test_operation = Multiplication() def test_complexconjugate_positive(self):
constant_operation = Constant(-5) test_operation = ComplexConjugate()
constant_operation_2 = Constant(-3) assert test_operation.evaluate_output(0, [3+4j]) == 3-4j
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 15
def test_test_complexconjugate_negative(self):
test_operation = ComplexConjugate()
assert test_operation.evaluate_output(0, [-3-4j]) == -3+4j
def test_multiplication_complex():
test_operation = Multiplication()
constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j)
# Division tests. class TestMax:
def test_max_positive(self):
test_operation = Max()
assert test_operation.evaluate_output(0, [30, 5]) == 30
def test_max_negative(self):
test_operation = Max()
assert test_operation.evaluate_output(0, [-30, -5]) == -5
def test_division():
test_operation = Division()
constant_operation = Constant(30)
constant_operation_2 = Constant(5)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 6
class TestMin:
def test_min_positive(self):
test_operation = Min()
assert test_operation.evaluate_output(0, [30, 5]) == 5
def test_division_negative(): def test_min_negative(self):
test_operation = Division() test_operation = Min()
constant_operation = Constant(-30) assert test_operation.evaluate_output(0, [-30, -5]) == -30
constant_operation_2 = Constant(-5)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 6
def test_division_complex(): class TestAbsolute:
test_operation = Division() def test_absolute_positive(self):
constant_operation = Constant((60+40j)) test_operation = Absolute()
constant_operation_2 = Constant((10+20j)) assert test_operation.evaluate_output(0, [30]) == 30
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j)
# SquareRoot tests. def test_absolute_negative(self):
test_operation = Absolute()
assert test_operation.evaluate_output(0, [-5]) == 5
def test_absolute_complex(self):
test_operation = Absolute()
assert test_operation.evaluate_output(0, [3+4j]) == 5.0
def test_squareroot():
test_operation = SquareRoot()
constant_operation = Constant(36)
assert test_operation.evaluate(constant_operation.evaluate()) == 6
class TestConstantMultiplication:
def test_constantmultiplication_positive(self):
test_operation = ConstantMultiplication(5)
assert test_operation.evaluate_output(0, [20]) == 100
def test_squareroot_negative(): def test_constantmultiplication_negative(self):
test_operation = SquareRoot() test_operation = ConstantMultiplication(5)
constant_operation = Constant(-36) assert test_operation.evaluate_output(0, [-5]) == -25
assert test_operation.evaluate(constant_operation.evaluate()) == 6j
def test_constantmultiplication_complex(self):
test_operation = ConstantMultiplication(3+2j)
assert test_operation.evaluate_output(0, [3+4j]) == 1+18j
def test_squareroot_complex():
test_operation = SquareRoot()
constant_operation = Constant((48+64j))
assert test_operation.evaluate(constant_operation.evaluate()) == (8+4j)
# ComplexConjugate tests. class TestButterfly:
def test_butterfly_positive(self):
test_operation = Butterfly()
assert test_operation.evaluate_output(0, [2, 3]) == 5
assert test_operation.evaluate_output(1, [2, 3]) == -1
def test_butterfly_negative(self):
test_operation = Butterfly()
assert test_operation.evaluate_output(0, [-2, -3]) == -5
assert test_operation.evaluate_output(1, [-2, -3]) == 1
def test_complexconjugate(): def test_buttefly_complex(self):
test_operation = ComplexConjugate() test_operation = Butterfly()
constant_operation = Constant(3+4j) assert test_operation.evaluate_output(0, [2+1j, 3-2j]) == 5-1j
assert test_operation.evaluate(constant_operation.evaluate()) == (3-4j) assert test_operation.evaluate_output(1, [2+1j, 3-2j]) == -1+3j
def test_test_complexconjugate_negative():
test_operation = ComplexConjugate()
constant_operation = Constant(-3-4j)
assert test_operation.evaluate(constant_operation.evaluate()) == (-3+4j)
# Max tests.
def test_max():
test_operation = Max()
constant_operation = Constant(30)
constant_operation_2 = Constant(5)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 30
def test_max_negative():
test_operation = Max()
constant_operation = Constant(-30)
constant_operation_2 = Constant(-5)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -5
# Min tests.
def test_min():
test_operation = Min()
constant_operation = Constant(30)
constant_operation_2 = Constant(5)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 5
def test_min_negative():
test_operation = Min()
constant_operation = Constant(-30)
constant_operation_2 = Constant(-5)
assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -30
# Absolute tests.
def test_absolute():
test_operation = Absolute()
constant_operation = Constant(30)
assert test_operation.evaluate(constant_operation.evaluate()) == 30
def test_absolute_negative():
test_operation = Absolute()
constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == 5
def test_absolute_complex():
test_operation = Absolute()
constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == 5.0
# ConstantMultiplication tests.
def test_constantmultiplication():
test_operation = ConstantMultiplication(5)
constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 100
def test_constantmultiplication_negative():
test_operation = ConstantMultiplication(5)
constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -25
def test_constantmultiplication_complex():
test_operation = ConstantMultiplication(3+2j)
constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (1+18j)
# ConstantAddition tests.
def test_constantaddition():
test_operation = ConstantAddition(5)
constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 25
def test_constantaddition_negative():
test_operation = ConstantAddition(4)
constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -1
def test_constantaddition_complex():
test_operation = ConstantAddition(3+2j)
constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (6+6j)
# ConstantSubtraction tests.
def test_constantsubtraction():
test_operation = ConstantSubtraction(5)
constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 15
def test_constantsubtraction_negative():
test_operation = ConstantSubtraction(4)
constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -9
def test_constantsubtraction_complex():
test_operation = ConstantSubtraction(4+6j)
constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (-1-2j)
# ConstantDivision tests.
def test_constantdivision():
test_operation = ConstantDivision(5)
constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 4
def test_constantdivision_negative():
test_operation = ConstantDivision(4)
constant_operation = Constant(-20)
assert test_operation.evaluate(constant_operation.evaluate()) == -5
def test_constantdivision_complex():
test_operation = ConstantDivision(2+2j)
constant_operation = Constant((10+10j))
assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j)
def test_butterfly():
test_operation = Butterfly()
assert list(test_operation.evaluate(2, 3)) == [5, -1]
def test_butterfly_negative():
test_operation = Butterfly()
assert list(test_operation.evaluate(-2, -3)) == [-5, 1]
def test_buttefly_complex():
test_operation = Butterfly()
assert list(test_operation.evaluate(2+1j, 3-2j)) == [5-1j, -1+3j]
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
B-ASIC test suite for graph id generator. B-ASIC test suite for graph id generator.
""" """
from b_asic.signal_flow_graph import GraphIDGenerator, GraphID
import pytest import pytest
from b_asic import GraphIDGenerator, GraphID
@pytest.fixture @pytest.fixture
def graph_id_generator(): def graph_id_generator():
return GraphIDGenerator() return GraphIDGenerator()
......
...@@ -4,8 +4,7 @@ B-ASIC test suite for Inputport ...@@ -4,8 +4,7 @@ B-ASIC test suite for Inputport
import pytest import pytest
from b_asic import InputPort, OutputPort from b_asic import InputPort, OutputPort, Signal
from b_asic import Signal
@pytest.fixture @pytest.fixture
def inp_port(): def inp_port():
...@@ -74,28 +73,3 @@ def test_add_signal_then_disconnect(inp_port, s_w_source): ...@@ -74,28 +73,3 @@ def test_add_signal_then_disconnect(inp_port, s_w_source):
assert inp_port.signals == [] assert inp_port.signals == []
assert s_w_source.source.signals == [s_w_source] assert s_w_source.source.signals == [s_w_source]
assert s_w_source.destination is None assert s_w_source.destination is None
def test_set_value_length_pos_int(inp_port):
inp_port.value_length = 10
assert inp_port.value_length == 10
def test_set_value_length_zero(inp_port):
inp_port.value_length = 0
assert inp_port.value_length == 0
def test_set_value_length_neg_int(inp_port):
with pytest.raises(Exception):
inp_port.value_length = -10
def test_set_value_length_complex(inp_port):
with pytest.raises(Exception):
inp_port.value_length = (2+4j)
def test_set_value_length_float(inp_port):
with pytest.raises(Exception):
inp_port.value_length = 3.2
def test_set_value_length_pos_then_none(inp_port):
inp_port.value_length = 10
inp_port.value_length = None
assert inp_port.value_length is None
from b_asic.core_operations import Constant, Addition, ConstantAddition, Butterfly
from b_asic.signal import Signal
from b_asic.port import InputPort, OutputPort
import pytest import pytest
from b_asic import Constant, Addition
class TestTraverse: class TestTraverse:
def test_traverse_single_tree(self, operation): def test_traverse_single_tree(self, operation):
...@@ -13,19 +10,16 @@ class TestTraverse: ...@@ -13,19 +10,16 @@ class TestTraverse:
def test_traverse_tree(self, operation_tree): def test_traverse_tree(self, operation_tree):
"""Traverse a basic addition tree with two constants.""" """Traverse a basic addition tree with two constants."""
assert len(list(operation_tree.traverse())) == 3 assert len(list(operation_tree.traverse())) == 5
def test_traverse_large_tree(self, large_operation_tree): def test_traverse_large_tree(self, large_operation_tree):
"""Traverse a larger tree.""" """Traverse a larger tree."""
assert len(list(large_operation_tree.traverse())) == 7 assert len(list(large_operation_tree.traverse())) == 13
def test_traverse_type(self, large_operation_tree): def test_traverse_type(self, large_operation_tree):
traverse = list(large_operation_tree.traverse()) result = list(large_operation_tree.traverse())
assert len( assert len(list(filter(lambda type_: isinstance(type_, Addition), result))) == 3
list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3 assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4
assert len(
list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4
def test_traverse_loop(self, operation_tree): def test_traverse_loop(self, operation_graph_with_cycle):
# TODO: Construct a graph that contains a loop and make sure you can traverse it properly. assert len(list(operation_graph_with_cycle.traverse())) == 8
assert True \ No newline at end of file
""" """
B-ASIC test suite for OutputPort. B-ASIC test suite for OutputPort.
""" """
from b_asic import OutputPort, InputPort, Signal
import pytest import pytest
from b_asic import OutputPort, InputPort, Signal
@pytest.fixture @pytest.fixture
def output_port(): def output_port():
return OutputPort(None, 0) return OutputPort(None, 0)
...@@ -16,6 +18,7 @@ def input_port(): ...@@ -16,6 +18,7 @@ def input_port():
def list_of_input_ports(): def list_of_input_ports():
return [InputPort(None, i) for i in range(0, 3)] return [InputPort(None, i) for i in range(0, 3)]
class TestConnect: class TestConnect:
def test_multiple_ports(self, output_port, list_of_input_ports): def test_multiple_ports(self, output_port, list_of_input_ports):
"""Can multiple ports connect to an output port?""" """Can multiple ports connect to an output port?"""
......
...@@ -4,7 +4,7 @@ B-ASIC test suite for printing a SFG ...@@ -4,7 +4,7 @@ B-ASIC test suite for printing a SFG
from b_asic.signal_flow_graph import SFG from b_asic.signal_flow_graph import SFG
from b_asic.core_operations import Addition, Multiplication, Constant, ConstantAddition from b_asic.core_operations import Addition, Multiplication, Constant
from b_asic.port import InputPort, OutputPort from b_asic.port import InputPort, OutputPort
from b_asic.signal import Signal from b_asic.signal import Signal
from b_asic.special_operations import Input, Output from b_asic.special_operations import Input, Output
......