diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index a1a149d787f831405558b774993b1b0ef86fe0be..98523fbbc3ad6cf26c028e028429c60710578d33 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -30,12 +30,12 @@ class Constant(AbstractOperation): @property def value(self) -> Number: - """TODO: docstring""" + """Get the constant value of this operation.""" return self.param("value") @value.setter - def value(self, value: Number): - """TODO: docstring""" + def value(self, value: Number) -> None: + """Set the constant value of this operation.""" return self.set_param("value", value) @@ -103,36 +103,22 @@ class Division(AbstractOperation): return a / b -class SquareRoot(AbstractOperation): - """Unary square root operation. - TODO: More info. - """ - - def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) - - @property - def type_name(self) -> TypeName: - return "sqrt" - - def evaluate(self, a): - return sqrt(complex(a)) - - -class ComplexConjugate(AbstractOperation): - """Unary complex conjugate operation. +class Min(AbstractOperation): + """Binary min operation. TODO: More info. """ - def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) @property def type_name(self) -> TypeName: - return "conj" + return "min" - def evaluate(self, a): - return conjugate(a) + def evaluate(self, a, b): + assert not isinstance(a, complex) and not isinstance(b, complex), \ + ("core_operations.Min does not support complex numbers.") + return a if a < b else b class Max(AbstractOperation): @@ -153,26 +139,8 @@ class Max(AbstractOperation): return a if a > b else b -class Min(AbstractOperation): - """Binary min operation. - TODO: More info. - """ - - def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) - - @property - def type_name(self) -> TypeName: - return "min" - - def evaluate(self, a, b): - assert not isinstance(a, complex) and not isinstance(b, complex), \ - ("core_operations.Min does not support complex numbers.") - return a if a < b else b - - -class Absolute(AbstractOperation): - """Unary absolute value operation. +class SquareRoot(AbstractOperation): + """Unary square root operation. TODO: More info. """ @@ -181,48 +149,46 @@ class Absolute(AbstractOperation): @property def type_name(self) -> TypeName: - return "abs" + return "sqrt" def evaluate(self, a): - return np_abs(a) + return sqrt(complex(a)) -class ConstantMultiplication(AbstractOperation): - """Unary constant multiplication operation. +class ComplexConjugate(AbstractOperation): + """Unary complex conjugate operation. TODO: More info. """ - def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): + def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) - self.set_param("value", value) @property def type_name(self) -> TypeName: - return "cmul" + return "conj" def evaluate(self, a): - return a * self.param("value") + return conjugate(a) -class ConstantAddition(AbstractOperation): - """Unary constant addition operation. +class Absolute(AbstractOperation): + """Unary absolute value operation. TODO: More info. """ - def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): + def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) - self.set_param("value", value) @property def type_name(self) -> TypeName: - return "cadd" + return "abs" def evaluate(self, a): - return a + self.param("value") + return np_abs(a) -class ConstantSubtraction(AbstractOperation): - """Unary constant subtraction operation. +class ConstantMultiplication(AbstractOperation): + """Unary constant multiplication operation. TODO: More info. """ @@ -232,27 +198,21 @@ class ConstantSubtraction(AbstractOperation): @property def type_name(self) -> TypeName: - return "csub" + return "cmul" def evaluate(self, a): - return a - self.param("value") - - -class ConstantDivision(AbstractOperation): - """Unary constant division operation. - TODO: More info. - """ - - def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) - self.set_param("value", value) + return a * self.param("value") @property - def type_name(self) -> TypeName: - return "cdiv" + def value(self) -> Number: + """Get the constant value of this operation.""" + return self.param("value") + + @value.setter + def value(self, value: Number) -> None: + """Set the constant value of this operation.""" + return self.set_param("value", value) - def evaluate(self, a): - return a / self.param("value") class Butterfly(AbstractOperation): """Butterfly operation that returns two outputs. @@ -263,9 +223,9 @@ class Butterfly(AbstractOperation): def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): super().__init__(input_count = 2, output_count = 2, name = name, input_sources = [src0, src1]) - def evaluate(self, a, b): - return a + b, a - b - @property def type_name(self) -> TypeName: return "bfly" + + def evaluate(self, a, b): + return a + b, a - b diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py index 52eba17c7e343842f636870d5d9a8fa694b713da..e37997016a3276f4dbde0394c44bf5c54ecbd51c 100644 --- a/b_asic/graph_component.py +++ b/b_asic/graph_component.py @@ -4,11 +4,15 @@ TODO: More info. """ from abc import ABC, abstractmethod -from copy import copy -from typing import NewType +from collections import deque +from copy import copy, deepcopy +from typing import NewType, Any, Dict, Mapping, Iterable, Generator + Name = NewType("Name", str) TypeName = NewType("TypeName", str) +GraphID = NewType("GraphID", str) +GraphIDNumber = NewType("GraphIDNumber", int) class GraphComponent(ABC): @@ -19,37 +23,87 @@ class GraphComponent(ABC): @property @abstractmethod def type_name(self) -> TypeName: - """Return the type name of the graph component""" + """Get the type name of this graph component""" raise NotImplementedError @property @abstractmethod def name(self) -> Name: - """Return the name of the graph component.""" + """Get the name of this graph component.""" raise NotImplementedError @name.setter @abstractmethod def name(self, name: Name) -> None: - """Set the name of the graph component to the entered name.""" + """Set the name of this graph component to the given name.""" + raise NotImplementedError + + @property + @abstractmethod + def graph_id(self) -> GraphID: + """Get the graph id of this graph component.""" raise NotImplementedError + @graph_id.setter @abstractmethod - def copy_unconnected(self) -> "GraphComponent": - """Get a copy of this graph component, except without any connected components.""" + def graph_id(self, graph_id: GraphID) -> None: + """Set the graph id of this graph component to the given id. + Note that this id will be ignored if this component is used to create a new graph, + and that a new local id will be generated for it instead.""" + raise NotImplementedError + + @property + @abstractmethod + def params(self) -> Mapping[str, Any]: + """Get a dictionary of all parameter values.""" + raise NotImplementedError + + @abstractmethod + def param(self, name: str) -> Any: + """Get the value of a parameter. + Returns None if the parameter is not defined. + """ + raise NotImplementedError + + @abstractmethod + def set_param(self, name: str, value: Any) -> None: + """Set the value of a parameter. + Adds the parameter if it is not already defined. + """ + raise NotImplementedError + + @abstractmethod + def copy_component(self, *args, **kwargs) -> "GraphComponent": + """Get a new instance of this graph component type with the same name, id and parameters.""" + raise NotImplementedError + + @property + @abstractmethod + def neighbors(self) -> Iterable["GraphComponent"]: + """Get all components that are directly connected to this operation.""" + raise NotImplementedError + + @abstractmethod + def traverse(self) -> Generator["GraphComponent", None, None]: + """Get a generator that recursively iterates through all components that are connected to this operation, + as well as the ones that they are connected to. + """ raise NotImplementedError class AbstractGraphComponent(GraphComponent): """Abstract Graph Component class which is a component of a signal flow graph. - TODO: More info. """ _name: Name + _graph_id: GraphID + _parameters: Dict[str, Any] def __init__(self, name: Name = ""): self._name = name + self._graph_id = "" + self._parameters = {} @property def name(self) -> Name: @@ -58,8 +112,41 @@ class AbstractGraphComponent(GraphComponent): @name.setter def name(self, name: Name) -> None: self._name = name + + @property + def graph_id(self) -> GraphID: + return self._graph_id + + @graph_id.setter + def graph_id(self, graph_id: GraphID) -> None: + self._graph_id = graph_id - def copy_unconnected(self) -> GraphComponent: - new_comp = self.__class__() - new_comp.name = copy(self.name) - return new_comp \ No newline at end of file + @property + def params(self) -> Mapping[str, Any]: + return self._parameters.copy() + + def param(self, name: str) -> Any: + return self._parameters.get(name) + + def set_param(self, name: str, value: Any) -> None: + self._parameters[name] = value + + def copy_component(self, *args, **kwargs) -> GraphComponent: + new_component = self.__class__(*args, **kwargs) + new_component.name = copy(self.name) + new_component.graph_id = copy(self.graph_id) + for name, value in self.params.items(): + new_component.set_param(copy(name), deepcopy(value)) # pylint: disable=no-member + return new_component + + def traverse(self) -> Generator[GraphComponent, None, None]: + # Breadth first search. + visited = {self} + fontier = deque([self]) + while fontier: + component = fontier.popleft() + yield component + for neighbor in component.neighbors: + if neighbor not in visited: + visited.add(neighbor) + fontier.append(neighbor) \ No newline at end of file diff --git a/b_asic/operation.py b/b_asic/operation.py index ed327127991db7a8d3c641c519ebf44257224fc7..a0d0f48a1f7429ce0d393ad4e93ef24c84914f7b 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -6,14 +6,19 @@ TODO: More info. import collections from abc import abstractmethod -from copy import deepcopy from numbers import Number -from typing import List, Sequence, Iterable, Dict, Optional, Any, Set, Generator, Union -from collections import deque +from typing import NewType, List, Sequence, Iterable, Mapping, MutableMapping, Optional, Any, Set, Union +from math import trunc from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name from b_asic.port import SignalSourceProvider, InputPort, OutputPort +from b_asic.signal import Signal +ResultKey = NewType("ResultKey", str) +ResultMap = Mapping[ResultKey, Optional[Number]] +MutableResultMap = MutableMapping[ResultKey, Optional[Number]] +RegisterMap = Mapping[ResultKey, Number] +MutableRegisterMap = MutableMapping[ResultKey, Number] class Operation(GraphComponent, SignalSourceProvider): """Operation interface. @@ -21,18 +26,30 @@ class Operation(GraphComponent, SignalSourceProvider): """ @abstractmethod - def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]": + def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": """Overloads the addition operator to make it return a new Addition operation - object that is connected to the self and other objects. If other is a number then - returns a ConstantAddition operation object instead. + object that is connected to the self and other objects. """ raise NotImplementedError @abstractmethod - def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]": + def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": + """Overloads the addition operator to make it return a new Addition operation + object that is connected to the self and other objects. + """ + raise NotImplementedError + + @abstractmethod + def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": """Overloads the subtraction operator to make it return a new Subtraction operation - object that is connected to the self and other objects. If other is a number then - returns a ConstantSubtraction operation object instead. + object that is connected to the self and other objects. + """ + raise NotImplementedError + + @abstractmethod + def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": + """Overloads the subtraction operator to make it return a new Subtraction operation + object that is connected to the self and other objects. """ raise NotImplementedError @@ -45,23 +62,25 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]": - """Overloads the division operator to make it return a new Division operation + def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": + """Overloads the multiplication operator to make it return a new Multiplication operation object that is connected to the self and other objects. If other is a number then - returns a ConstantDivision operation object instead. + returns a ConstantMultiplication operation object instead. """ raise NotImplementedError - @property @abstractmethod - def inputs(self) -> List[InputPort]: - """Get a list of all input ports.""" + def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": + """Overloads the division operator to make it return a new Division operation + object that is connected to the self and other objects. + """ raise NotImplementedError - @property @abstractmethod - def outputs(self) -> List[OutputPort]: - """Get a list of all output ports.""" + def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": + """Overloads the division operator to make it return a new Division operation + object that is connected to the self and other objects. + """ raise NotImplementedError @property @@ -77,64 +96,87 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def input(self, i: int) -> InputPort: - """Get the input port at index i.""" + def input(self, index: int) -> InputPort: + """Get the input port at the given index.""" raise NotImplementedError @abstractmethod - def output(self, i: int) -> OutputPort: - """Get the output port at index i.""" + def output(self, index: int) -> OutputPort: + """Get the output port at the given index.""" raise NotImplementedError + @property @abstractmethod - def params(self) -> Dict[str, Optional[Any]]: - """Get a dictionary of all parameter values.""" + def inputs(self) -> Sequence[InputPort]: + """Get all input ports.""" raise NotImplementedError + @property @abstractmethod - def param(self, name: str) -> Optional[Any]: - """Get the value of a parameter. - Returns None if the parameter is not defined. + def outputs(self) -> Sequence[OutputPort]: + """Get all output ports.""" + raise NotImplementedError + + @property + @abstractmethod + def input_signals(self) -> Iterable[Signal]: + """Get all the signals that are connected to this operation's input ports, + in no particular order. """ raise NotImplementedError + @property @abstractmethod - def set_param(self, name: str, value: Any) -> None: - """Set the value of a parameter. - Adds the parameter if it is not already defined. + def output_signals(self) -> Iterable[Signal]: + """Get all the signals that are connected to this operation's output ports, + in no particular order. """ raise NotImplementedError @abstractmethod - def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]: - """Evaluate the output at index i of this operation with the given input values. - The returned sequence contains results corresponding to each output of this operation, - where a value of None means it was not evaluated. - The value at index i is guaranteed to have been evaluated, while the others may or may not - have been evaluated depending on what is the most efficient. - For example, Butterfly().evaluate_output(1, [5, 4]) may result in either (9, 1) or (None, 1). + def key(self, index: int, prefix: str = "") -> ResultKey: + """Get the key used to access the result of a certain output of this operation + from the results parameter passed to current_output(s) or evaluate_output(s). """ raise NotImplementedError @abstractmethod - def split(self) -> Iterable["Operation"]: - """Split the operation into multiple operations. - If splitting is not possible, this may return a list containing only the operation itself. + def current_output(self, index: int, registers: Optional[RegisterMap] = None, prefix: str = "") -> Optional[Number]: + """Get the current output at the given index of this operation, if available. + The registers parameter will be used for lookup. + The prefix parameter will be used as a prefix for the key string when looking for registers. + See also: current_outputs, evaluate_output, evaluate_outputs. + """ + raise NotImplementedError + + @abstractmethod + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: + """Evaluate the output at the given index of this operation with the given input values. + The results parameter will be used to store any results (including intermediate results) for caching. + The registers parameter will be used to get the current value of any intermediate registers that are encountered, and be updated with their new values. + The prefix parameter will be used as a prefix for the key string when storing results/registers. + See also: evaluate_outputs, current_output, current_outputs. """ raise NotImplementedError - @property @abstractmethod - def neighbors(self) -> Iterable["Operation"]: - """Return all operations that are connected by signals to this operation. - If no neighbors are found, this returns an empty list. + def current_outputs(self, registers: Optional[RegisterMap] = None, prefix: str = "") -> Sequence[Optional[Number]]: + """Get all current outputs of this operation, if available. + See current_output for more information. """ raise NotImplementedError @abstractmethod - def traverse(self) -> Generator["Operation", None, None]: - """Get a generator that recursively iterates through all operations that are connected by signals to this operation, - as well as the ones that they are connected to. + def evaluate_outputs(self, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Sequence[Number]: + """Evaluate all outputs of this operation given the input values. + See evaluate_output for more information. + """ + raise NotImplementedError + + @abstractmethod + def split(self) -> Iterable["Operation"]: + """Split the operation into multiple operations. + If splitting is not possible, this may return a list containing only the operation itself. """ raise NotImplementedError @@ -146,116 +188,142 @@ class AbstractOperation(Operation, AbstractGraphComponent): _input_ports: List[InputPort] _output_ports: List[OutputPort] - _parameters: Dict[str, Optional[Any]] def __init__(self, input_count: int, output_count: int, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None): super().__init__(name) - self._input_ports = [] - self._output_ports = [] - self._parameters = {} - - # Allocate input ports. - for i in range(input_count): - self._input_ports.append(InputPort(self, i)) - - # Allocate output ports. - for i in range(output_count): - self._output_ports.append(OutputPort(self, i)) + self._input_ports = [InputPort(self, i) for i in range(input_count)] # Allocate input ports. + self._output_ports = [OutputPort(self, i) for i in range(output_count)] # Allocate output ports. # Connect given input sources, if any. if input_sources is not None: source_count = len(input_sources) if source_count != input_count: - raise ValueError( - f"Operation expected {input_count} input sources but only got {source_count}") + raise ValueError(f"Wrong number of input sources supplied to Operation (expected {input_count}, got {source_count})") for i, src in enumerate(input_sources): if src is not None: self._input_ports[i].connect(src.source) @abstractmethod - def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ - """Evaluate the operation and generate a list of output values given a - list of input values. - """ + def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ + """Evaluate the operation and generate a list of output values given a list of input values.""" raise NotImplementedError - def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]": - # Import here to avoid circular imports. - from b_asic.core_operations import Addition, ConstantAddition - - if isinstance(src, Number): - return ConstantAddition(src, self) - return Addition(self, src) - - def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]": - # Import here to avoid circular imports. - from b_asic.core_operations import Subtraction, ConstantSubtraction - - if isinstance(src, Number): - return ConstantSubtraction(src, self) - return Subtraction(self, src) + def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": + from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports. + return Addition(self, Constant(src) if isinstance(src, Number) else src) + + def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": + from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports. + return Addition(Constant(src) if isinstance(src, Number) else src, self) + + def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": + from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports. + return Subtraction(self, Constant(src) if isinstance(src, Number) else src) + + def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": + from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports. + return Subtraction(Constant(src) if isinstance(src, Number) else src, self) def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": - # Import here to avoid circular imports. - from b_asic.core_operations import Multiplication, ConstantMultiplication + from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports. + return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(self, src) + + def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": + from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports. + return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(src, self) + + def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": + from b_asic.core_operations import Constant, Division # Import here to avoid circular imports. + return Division(self, Constant(src) if isinstance(src, Number) else src) + + def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": + from b_asic.core_operations import Constant, Division # Import here to avoid circular imports. + return Division(Constant(src) if isinstance(src, Number) else src, self) + + @property + def input_count(self) -> int: + return len(self._input_ports) - if isinstance(src, Number): - return ConstantMultiplication(src, self) - return Multiplication(self, src) + @property + def output_count(self) -> int: + return len(self._output_ports) - def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]": - # Import here to avoid circular imports. - from b_asic.core_operations import Division, ConstantDivision + def input(self, index: int) -> InputPort: + return self._input_ports[index] - if isinstance(src, Number): - return ConstantDivision(src, self) - return Division(self, src) + def output(self, index: int) -> OutputPort: + return self._output_ports[index] @property - def inputs(self) -> List[InputPort]: - return self._input_ports.copy() + def inputs(self) -> Sequence[InputPort]: + return self._input_ports @property - def outputs(self) -> List[OutputPort]: - return self._output_ports.copy() + def outputs(self) -> Sequence[OutputPort]: + return self._output_ports @property - def input_count(self) -> int: - return len(self._input_ports) + def input_signals(self) -> Iterable[Signal]: + result = [] + for p in self.inputs: + for s in p.signals: + result.append(s) + return result @property - def output_count(self) -> int: - return len(self._output_ports) + def output_signals(self) -> Iterable[Signal]: + result = [] + for p in self.outputs: + for s in p.signals: + result.append(s) + return result + + def key(self, index: int, prefix: str = "") -> ResultKey: + key = prefix + if self.output_count != 1: + if key: + key += "." + key += str(index) + elif not key: + key = str(index) + return key + + def current_output(self, index: int, registers: Optional[RegisterMap] = None, prefix: str = "") -> Optional[Number]: + return None + + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: + if index < 0 or index >= self.output_count: + raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") + if len(input_values) != self.input_count: + raise ValueError(f"Wrong number of input values supplied to operation (expected {self.input_count}, got {len(input_values)})") + if results is None: + results = {} + if registers is None: + registers = {} + + values = self.evaluate(*self.truncate_inputs(input_values)) + if isinstance(values, collections.abc.Sequence): + if len(values) != self.output_count: + raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(values)})") + elif isinstance(values, Number): + if self.output_count != 1: + raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got 1)") + values = (values,) + else: + raise RuntimeError(f"Operation evaluated to invalid type (expected Sequence/Number, got {values.__class__.__name__})") - def input(self, i: int) -> InputPort: - return self._input_ports[i] + if self.output_count == 1: + results[self.key(index, prefix)] = values[index] + else: + for i in range(self.output_count): + results[self.key(i, prefix)] = values[i] + return values[index] - def output(self, i: int) -> OutputPort: - return self._output_ports[i] + def current_outputs(self, registers: Optional[RegisterMap] = None, prefix: str = "") -> Sequence[Optional[Number]]: + return [self.current_output(i, registers, prefix) for i in range(self.output_count)] - @property - def params(self) -> Dict[str, Optional[Any]]: - return self._parameters.copy() - - def param(self, name: str) -> Optional[Any]: - return self._parameters.get(name) - - def set_param(self, name: str, value: Any) -> None: - self._parameters[name] = value - - def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]: - result = self.evaluate(*input_values) - if isinstance(result, collections.Sequence): - if len(result) != self.output_count: - raise RuntimeError( - "Operation evaluated to incorrect number of outputs") - return result - if isinstance(result, Number): - if self.output_count != 1: - raise RuntimeError( - "Operation evaluated to incorrect number of outputs") - return [result] - raise RuntimeError("Operation evaluated to invalid type") + def evaluate_outputs(self, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Sequence[Number]: + return [self.evaluate_output(i, input_values, results, registers, prefix) for i in range(self.output_count)] def split(self) -> Iterable[Operation]: # Import here to avoid circular imports. @@ -273,27 +341,8 @@ class AbstractOperation(Operation, AbstractGraphComponent): return [self] @property - def neighbors(self) -> Iterable[Operation]: - neighbors = [] - for port in self._input_ports: - for signal in port.signals: - neighbors.append(signal.source.operation) - for port in self._output_ports: - for signal in port.signals: - neighbors.append(signal.destination.operation) - return neighbors - - def traverse(self) -> Generator[Operation, None, None]: - # Breadth first search. - visited = {self} - queue = deque([self]) - while queue: - operation = queue.popleft() - yield operation - for n_operation in operation.neighbors: - if n_operation not in visited: - visited.add(n_operation) - queue.append(n_operation) + def neighbors(self) -> Iterable[GraphComponent]: + return list(self.input_signals) + list(self.output_signals) @property def source(self) -> OutputPort: @@ -302,10 +351,26 @@ class AbstractOperation(Operation, AbstractGraphComponent): raise TypeError( f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output") return self.output(0) - - def copy_unconnected(self) -> GraphComponent: - new_comp: AbstractOperation = super().copy_unconnected() - for name, value in self.params.items(): - new_comp.set_param(name, deepcopy( - value)) # pylint: disable=no-member - return new_comp + + def truncate_input(self, index: int, value: Number, bits: int) -> Number: + """Truncate the value to be used as input at the given index to a certain bit length.""" + n = value + if not isinstance(n, int): + n = trunc(value) + return n & ((2 ** bits) - 1) + + def truncate_inputs(self, input_values: Sequence[Number]) -> Sequence[Number]: + """Truncate the values to be used as inputs to the bit lengths specified by the respective signals connected to each input.""" + args = [] + for i, input_port in enumerate(self.inputs): + if input_port.signal_count >= 1: + bits = input_port.signals[0].bits + if bits is None: + args.append(input_values[i]) + else: + if isinstance(input_values[i], complex): + raise TypeError("Complex value cannot be truncated to {bits} bits as requested by the signal connected to input #{i}") + args.append(self.truncate_input(i, input_values[i], bits)) + else: + args.append(input_values[i]) + return args diff --git a/b_asic/port.py b/b_asic/port.py index 103d076af2702e7e565067f7568bb6035d24a2c8..e8c007cbf077f9f40df0d53fc08001e6436f0093 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -108,12 +108,10 @@ class InputPort(AbstractPort): """ _source_signal: Optional[Signal] - _value_length: Optional[int] def __init__(self, operation: "Operation", index: int): super().__init__(operation, index) self._source_signal = None - self._value_length = None @property def signal_count(self) -> int: @@ -153,18 +151,6 @@ class InputPort(AbstractPort): # self._source_signal is set by the signal constructor. return Signal(source=src.source, destination=self, name=name) - @property - def value_length(self) -> Optional[int]: - """Get the number of bits that this port should truncate received values to.""" - return self._value_length - - @value_length.setter - def value_length(self, bits: Optional[int]) -> None: - """Set the number of bits that this port should truncate received values to.""" - assert bits is None or (isinstance( - bits, int) and bits >= 0), "Value length must be non-negative." - self._value_length = bits - class OutputPort(AbstractPort, SignalSourceProvider): """Output port. diff --git a/b_asic/signal.py b/b_asic/signal.py index 67e1d0f908ba57f5d355e77794993587343e63cf..d322f161b1d9c1195c2f8e5beb40f0ea7244a25b 100644 --- a/b_asic/signal.py +++ b/b_asic/signal.py @@ -1,9 +1,9 @@ """@package docstring B-ASIC Signal Module. """ -from typing import Optional, TYPE_CHECKING +from typing import Optional, Iterable, TYPE_CHECKING -from b_asic.graph_component import AbstractGraphComponent, TypeName, Name +from b_asic.graph_component import GraphComponent, AbstractGraphComponent, TypeName, Name if TYPE_CHECKING: from b_asic.port import InputPort, OutputPort @@ -15,8 +15,7 @@ class Signal(AbstractGraphComponent): _source: Optional["OutputPort"] _destination: Optional["InputPort"] - def __init__(self, source: Optional["OutputPort"] = None, \ - destination: Optional["InputPort"] = None, name: Name = ""): + def __init__(self, source: Optional["OutputPort"] = None, destination: Optional["InputPort"] = None, bits: Optional[int] = None, name: Name = ""): super().__init__(name) self._source = None self._destination = None @@ -24,7 +23,16 @@ class Signal(AbstractGraphComponent): self.set_source(source) if destination is not None: self.set_destination(destination) + self.set_param("bits", bits) + @property + def type_name(self) -> TypeName: + return "s" + + @property + def neighbors(self) -> Iterable[GraphComponent]: + return [p.operation for p in [self.source, self.destination] if p is not None] + @property def source(self) -> Optional["OutputPort"]: """Return the source OutputPort of the signal.""" @@ -63,10 +71,6 @@ class Signal(AbstractGraphComponent): if self not in dest.signals: dest.add_signal(self) - @property - def type_name(self) -> TypeName: - return "s" - def remove_source(self) -> None: """Disconnect the source OutputPort of the signal. If the source port still is connected to this signal then also disconnect the source port.""" @@ -88,3 +92,16 @@ class Signal(AbstractGraphComponent): """Returns true if the signal is missing either a source or a destination, else false.""" return self._source is None or self._destination is None + + @property + def bits(self) -> Optional[int]: + """Get the number of bits that this operations using this signal as an input should truncate received values to. + None = unlimited.""" + return self.param("bits") + + @bits.setter + def bits(self, bits: Optional[int]) -> None: + """Set the number of bits that operations using this signal as an input should truncate received values to. + None = unlimited.""" + assert bits is None or (isinstance(bits, int) and bits >= 0), "Bits must be non-negative." + self.set_param("bits", bits) \ No newline at end of file diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 2f2a024053d41b2f15a4eb18ceb239c5832ff95d..e8e7af01ab93fdba948d9ff7ec19078b3b71dee6 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -3,21 +3,17 @@ B-ASIC Signal Flow Graph Module. TODO: More info. """ -from typing import NewType, List, Iterable, Sequence, Dict, Optional, DefaultDict, Set +from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, Set from numbers import Number from collections import defaultdict, deque from b_asic.port import SignalSourceProvider, OutputPort -from b_asic.operation import Operation, AbstractOperation +from b_asic.operation import Operation, AbstractOperation, ResultKey, RegisterMap, MutableResultMap, MutableRegisterMap from b_asic.signal import Signal -from b_asic.graph_component import GraphComponent, Name, TypeName +from b_asic.graph_component import GraphID, GraphIDNumber, GraphComponent, Name, TypeName from b_asic.special_operations import Input, Output -GraphID = NewType("GraphID", str) -GraphIDNumber = NewType("GraphIDNumber", int) - - class GraphIDGenerator: """A class that generates Graph IDs for objects.""" @@ -27,10 +23,15 @@ class GraphIDGenerator: self._next_id_number = defaultdict(lambda: id_number_offset) def next_id(self, type_name: TypeName) -> GraphID: - """Return the next graph id for a certain graph id type.""" + """Get the next graph id for a certain graph id type.""" self._next_id_number[type_name] += 1 return type_name + str(self._next_id_number[type_name]) + @property + def id_number_offset(self) -> GraphIDNumber: + """Get the graph id number offset of this generator.""" + return self._next_id_number.default_factory() # pylint: disable=not-callable + class SFG(AbstractOperation): """Signal flow graph. @@ -39,363 +40,348 @@ class SFG(AbstractOperation): _components_by_id: Dict[GraphID, GraphComponent] _components_by_name: DefaultDict[Name, List[GraphComponent]] + _components_ordered: List[GraphComponent] + _operations_ordered: List[Operation] _graph_id_generator: GraphIDGenerator _input_operations: List[Input] _output_operations: List[Output] - _original_components_added: Set[GraphComponent] - _original_input_signals: Dict[Signal, int] - _original_output_signals: Dict[Signal, int] + _original_components_to_new: Set[GraphComponent] + _original_input_signals_to_indices: Dict[Signal, int] + _original_output_signals_to_indices: Dict[Signal, int] - def __init__(self, input_signals: Sequence[Signal] = [], output_signals: Sequence[Signal] = [], - inputs: Sequence[Input] = [], outputs: Sequence[Output] = [], - id_number_offset: GraphIDNumber = 0, name: Name = "", + def __init__(self, input_signals: Optional[Sequence[Signal]] = None, output_signals: Optional[Sequence[Signal]] = None, \ + inputs: Optional[Sequence[Input]] = None, outputs: Optional[Sequence[Output]] = None, \ + id_number_offset: GraphIDNumber = 0, name: Name = "", \ input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None): - super().__init__( - input_count=len(input_signals) + len(inputs), - output_count=len(output_signals) + len(outputs), - name=name, - input_sources=input_sources) + input_signal_count = 0 if input_signals is None else len(input_signals) + input_operation_count = 0 if inputs is None else len(inputs) + output_signal_count = 0 if output_signals is None else len(output_signals) + output_operation_count = 0 if outputs is None else len(outputs) + super().__init__(input_count = input_signal_count + input_operation_count, + output_count = output_signal_count + output_operation_count, + name = name, input_sources = input_sources) self._components_by_id = dict() self._components_by_name = defaultdict(list) - self._components_in_dfs_order = [] + self._components_ordered = [] + self._operations_ordered = [] self._graph_id_generator = GraphIDGenerator(id_number_offset) self._input_operations = [] self._output_operations = [] - # Maps original components to new copied components - self._added_components_mapping = {} - self._original_input_signals_indexes = {} - self._original_output_signals_indexes = {} - self._id_number_offset = id_number_offset + self._original_components_to_new = {} + self._original_input_signals_to_indices = {} + self._original_output_signals_to_indices = {} # Setup input signals. - for input_index, sig in enumerate(input_signals): - assert sig not in self._added_components_mapping, "Duplicate input signals sent to SFG construcctor." - - new_input_op = self._add_component_copy_unconnected(Input()) - new_sig = self._add_component_copy_unconnected(sig) - new_sig.set_source(new_input_op.output(0)) - - self._input_operations.append(new_input_op) - self._original_input_signals_indexes[sig] = input_index - - # Setup input operations, starting from indexes ater input signals. - for input_index, input_op in enumerate(inputs, len(input_signals)): - assert input_op not in self._added_components_mapping, "Duplicate input operations sent to SFG constructor." - new_input_op = self._add_component_copy_unconnected(input_op) - - for sig in input_op.output(0).signals: - assert sig not in self._added_components_mapping, "Duplicate input signals connected to input ports sent to SFG construcctor." - new_sig = self._add_component_copy_unconnected(sig) - new_sig.set_source(new_input_op.output(0)) - - self._original_input_signals_indexes[sig] = input_index - - self._input_operations.append(new_input_op) + if input_signals is not None: + for input_index, signal in enumerate(input_signals): + assert signal not in self._original_components_to_new, "Duplicate input signals supplied to SFG construcctor." + new_input_op = self._add_component_unconnected_copy(Input()) + new_signal = self._add_component_unconnected_copy(signal) + new_signal.set_source(new_input_op.output(0)) + self._input_operations.append(new_input_op) + self._original_input_signals_to_indices[signal] = input_index + + # Setup input operations, starting from indices ater input signals. + if inputs is not None: + for input_index, input_op in enumerate(inputs, input_signal_count): + assert input_op not in self._original_components_to_new, "Duplicate input operations supplied to SFG constructor." + new_input_op = self._add_component_unconnected_copy(input_op) + for signal in input_op.output(0).signals: + assert signal not in self._original_components_to_new, "Duplicate input signals connected to input ports supplied to SFG construcctor." + new_signal = self._add_component_unconnected_copy(signal) + new_signal.set_source(new_input_op.output(0)) + self._original_input_signals_to_indices[signal] = input_index + + self._input_operations.append(new_input_op) # Setup output signals. - for output_ind, sig in enumerate(output_signals): - new_out = self._add_component_copy_unconnected(Output()) - if sig in self._added_components_mapping: - # Signal already added when setting up inputs - new_sig = self._added_components_mapping[sig] - new_sig.set_destination(new_out.input(0)) - else: - # New signal has to be created - new_sig = self._add_component_copy_unconnected(sig) - new_sig.set_destination(new_out.input(0)) - - self._output_operations.append(new_out) - self._original_output_signals_indexes[sig] = output_ind - - # Setup output operations, starting from indexes after output signals. - for output_ind, output_op in enumerate(outputs, len(output_signals)): - assert output_op not in self._added_components_mapping, "Duplicate output operations sent to SFG constructor." - - new_out = self._add_component_copy_unconnected(output_op) - for sig in output_op.input(0).signals: - if sig in self._added_components_mapping: - # Signal already added when setting up inputs - new_sig = self._added_components_mapping[sig] - new_sig.set_destination(new_out.input(0)) + if output_signals is not None: + for output_index, signal in enumerate(output_signals): + new_output_op = self._add_component_unconnected_copy(Output()) + if signal in self._original_components_to_new: + # Signal was already added when setting up inputs. + new_signal = self._original_components_to_new[signal] + new_signal.set_destination(new_output_op.input(0)) else: - # New signal has to be created - new_sig = self._add_component_copy_unconnected(sig) - new_sig.set_destination(new_out.input(0)) - - self._original_output_signals_indexes[sig] = output_ind - - self._output_operations.append(new_out) + # New signal has to be created. + new_signal = self._add_component_unconnected_copy(signal) + new_signal.set_destination(new_output_op.input(0)) + + self._output_operations.append(new_output_op) + self._original_output_signals_to_indices[signal] = output_index + + # Setup output operations, starting from indices after output signals. + if outputs is not None: + for output_index, output_op in enumerate(outputs, output_signal_count): + assert output_op not in self._original_components_to_new, "Duplicate output operations supplied to SFG constructor." + new_output_op = self._add_component_unconnected_copy(output_op) + for signal in output_op.input(0).signals: + new_signal = None + if signal in self._original_components_to_new: + # Signal was already added when setting up inputs. + new_signal = self._original_components_to_new[signal] + else: + # New signal has to be created. + new_signal = self._add_component_unconnected_copy(signal) + + new_signal.set_destination(new_output_op.input(0)) + self._original_output_signals_to_indices[signal] = output_index + + self._output_operations.append(new_output_op) output_operations_set = set(self._output_operations) # Search the graph inwards from each input signal. - for sig, input_index in self._original_input_signals_indexes.items(): + for signal, input_index in self._original_input_signals_to_indices.items(): # Check if already added destination. - new_sig = self._added_components_mapping[sig] - if new_sig.destination is not None and new_sig.destination.operation in output_operations_set: - # Add directly connected input to output to dfs order list - self._components_in_dfs_order.extend([ - new_sig.source.operation, new_sig, new_sig.destination.operation]) - elif sig.destination is None: - raise ValueError( - f"Input signal #{input_index} is missing destination in SFG") - elif sig.destination.operation not in self._added_components_mapping: - self._copy_structure_from_operation_dfs( - sig.destination.operation) + new_signal = self._original_components_to_new[signal] + if new_signal.destination is None: + if signal.destination is None: + raise ValueError(f"Input signal #{input_index} is missing destination in SFG") + if signal.destination.operation not in self._original_components_to_new: + self._add_operation_connected_tree_copy(signal.destination.operation) + elif new_signal.destination.operation in output_operations_set: + # Add directly connected input to output to ordered list. + self._components_ordered.extend([new_signal.source.operation, new_signal, new_signal.destination.operation]) + self._operations_ordered.extend([new_signal.source.operation, new_signal.destination.operation]) # Search the graph inwards from each output signal. - for sig, output_index in self._original_output_signals_indexes.items(): + for signal, output_index in self._original_output_signals_to_indices.items(): # Check if already added source. - mew_sig = self._added_components_mapping[sig] - if new_sig.source is None: - if sig.source is None: - raise ValueError( - f"Output signal #{output_index} is missing source in SFG") - if sig.source.operation not in self._added_components_mapping: - self._copy_structure_from_operation_dfs( - sig.source.operation) + new_signal = self._original_components_to_new[signal] + if new_signal.source is None: + if signal.source is None: + raise ValueError(f"Output signal #{output_index} is missing source in SFG") + if signal.source.operation not in self._original_components_to_new: + self._add_operation_connected_tree_copy(signal.source.operation) + + def __str__(self) -> str: + """Get a string representation of this SFG.""" + output_string = "" + for component in self._components_ordered: + if isinstance(component, Operation): + for key, value in self._components_by_id.items(): + if value is component: + output_string += "id: " + key + ", name: " + + if component.name != None: + output_string += component.name + ", " + else: + output_string += "-, " + + if component.type_name is "c": + output_string += "value: " + str(component.value) + ", input: [" + else: + output_string += "input: [" + + counter_input = 0 + for input in component.inputs: + counter_input += 1 + for signal in input.signals: + for key, value in self._components_by_id.items(): + if value is signal: + output_string += key + ", " + + if counter_input > 0: + output_string = output_string[:-2] + output_string += "], output: [" + counter_output = 0 + for output in component.outputs: + counter_output += 1 + for signal in output.signals: + for key, value in self._components_by_id.items(): + if value is signal: + output_string += key + ", " + if counter_output > 0: + output_string = output_string[:-2] + output_string += "]\n" + + return output_string + + def __call__(self, *src: Optional[SignalSourceProvider], name: Name = "") -> "SFG": + """Get a new independent SFG instance that is identical to this SFG except without any of its external connections.""" + return SFG(inputs = self._input_operations, outputs = self._output_operations, + id_number_offset = self.id_number_offset, name = name, input_sources = src if src else None) @property def type_name(self) -> TypeName: return "sfg" def evaluate(self, *args): - if len(args) != self.input_count: - raise ValueError( - "Wrong number of inputs supplied to SFG for evaluation") - for arg, op in zip(args, self._input_operations): - op.value = arg - - result = [] - for op in self._output_operations: - result.append(self._evaluate_source(op.input(0).signals[0].source)) - + result = self.evaluate_outputs(args, {}, {}, "") n = len(result) return None if n == 0 else result[0] if n == 1 else result - def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]: - assert i >= 0 and i < self.output_count, "Output index out of range" - result = [None] * self.output_count - result[i] = self._evaluate_source( - self._output_operations[i].input(0).signals[0].source) - return result + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: + if index < 0 or index >= self.output_count: + raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") + if len(input_values) != self.input_count: + raise ValueError(f"Wrong number of inputs supplied to SFG for evaluation (expected {self.input_count}, got {len(input_values)})") + if results is None: + results = {} + if registers is None: + registers = {} + + # Set the values of our input operations to the given input values. + for op, arg in zip(self._input_operations, self.truncate_inputs(input_values)): + op.value = arg + + value = self._evaluate_source(self._output_operations[index].input(0).signals[0].source, results, registers, prefix) + results[self.key(index, prefix)] = value + return value def split(self) -> Iterable[Operation]: - return filter(lambda comp: isinstance(comp, Operation), self._components_by_id.values()) + return self.operations + + def copy_component(self, *args, **kwargs) -> GraphComponent: + return super().copy_component(*args, **kwargs, inputs = self._input_operations, outputs = self._output_operations, + id_number_offset = self.id_number_offset, name = self.name) + + @property + def id_number_offset(self) -> GraphIDNumber: + """Get the graph id number offset of the graph id generator for this SFG.""" + return self._graph_id_generator.id_number_offset @property def components(self) -> Iterable[GraphComponent]: - """Get all components of this graph in the dfs-traversal order.""" - return self._components_in_dfs_order + """Get all components of this graph in depth-first order.""" + return self._components_ordered + + @property + def operations(self) -> Iterable[Operation]: + """Get all operations of this graph in depth-first order.""" + return self._operations_ordered def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]: - """Find a graph object based on the entered Graph ID and return it. If no graph - object with the entered ID was found then return None. + """Find the graph component with the specified ID. + Returns None if the component was not found. Keyword arguments: - graph_id: Graph ID of the wanted object. + graph_id: Graph ID of the desired component(s) """ return self._components_by_id.get(graph_id, None) - def find_by_name(self, name: Name) -> List[GraphComponent]: - """Find all graph objects that have the entered name and return them - in a list. If no graph object with the entered name was found then return an - empty list. + def find_by_name(self, name: Name) -> Sequence[GraphComponent]: + """Find all graph components with the specified name. + Returns an empty sequence if no components were found. Keyword arguments: - name: Name of the wanted object. + name: Name of the desired component(s) """ return self._components_by_name.get(name, []) - def deep_copy(self) -> "SFG": - """Returns a deep copy of self.""" - copy = SFG(inputs=self._input_operations, outputs=self._output_operations, - id_number_offset=self._id_number_offset, name=super().name) - - return copy - - def _add_component_copy_unconnected(self, original_comp: GraphComponent) -> GraphComponent: - - assert original_comp not in self._added_components_mapping, "Tried to add duplicate SFG component" - - new_comp = original_comp.copy_unconnected() - - self._added_components_mapping[original_comp] = new_comp - self._components_by_id[self._graph_id_generator.next_id( - new_comp.type_name)] = new_comp - self._components_by_name[new_comp.name].append(new_comp) - - return new_comp - - def _copy_structure_from_operation_dfs(self, start_op: Operation): + def _add_component_unconnected_copy(self, original_component: GraphComponent) -> GraphComponent: + assert original_component not in self._original_components_to_new, "Tried to add duplicate SFG component" + new_component = original_component.copy_component() + self._original_components_to_new[original_component] = new_component + new_id = self._graph_id_generator.next_id(new_component.type_name) + new_component.graph_id = new_id + self._components_by_id[new_id] = new_component + self._components_by_name[new_component.name].append(new_component) + return new_component + + def _add_operation_connected_tree_copy(self, start_op: Operation) -> None: op_stack = deque([start_op]) - while op_stack: original_op = op_stack.pop() - # Add or get the new copy of the operation.. + # Add or get the new copy of the operation. new_op = None - if original_op not in self._added_components_mapping: - new_op = self._add_component_copy_unconnected(original_op) - self._components_in_dfs_order.append(new_op) + if original_op not in self._original_components_to_new: + new_op = self._add_component_unconnected_copy(original_op) + self._components_ordered.append(new_op) + self._operations_ordered.append(new_op) else: - new_op = self._added_components_mapping[original_op] + new_op = self._original_components_to_new[original_op] - # Connect input ports to new signals + # Connect input ports to new signals. for original_input_port in original_op.inputs: if original_input_port.signal_count < 1: raise ValueError("Unconnected input port in SFG") for original_signal in original_input_port.signals: + # Check if the signal is one of the SFG's input signals. + if original_signal in self._original_input_signals_to_indices: + # New signal already created during first step of constructor. + new_signal = self._original_components_to_new[original_signal] + new_signal.set_destination(new_op.input(original_input_port.index)) + self._components_ordered.extend([new_signal, new_signal.source.operation]) + self._operations_ordered.append(new_signal.source.operation) - # Check if the signal is one of the SFG's input signals - if original_signal in self._original_input_signals_indexes: - - # New signal already created during first step of constructor - new_signal = self._added_components_mapping[ - original_signal] - new_signal.set_destination( - new_op.input(original_input_port.index)) - - self._components_in_dfs_order.extend( - [new_signal, new_signal.source.operation]) - - # Check if the signal has not been added before - elif original_signal not in self._added_components_mapping: + # Check if the signal has not been added before. + elif original_signal not in self._original_components_to_new: if original_signal.source is None: - raise ValueError( - "Dangling signal without source in SFG") - - new_signal = self._add_component_copy_unconnected( - original_signal) - new_signal.set_destination( - new_op.input(original_input_port.index)) - - self._components_in_dfs_order.append(new_signal) + raise ValueError("Dangling signal without source in SFG") + + new_signal = self._add_component_unconnected_copy(original_signal) + new_signal.set_destination(new_op.input(original_input_port.index)) + self._components_ordered.append(new_signal) original_connected_op = original_signal.source.operation - # Check if connected Operation has been added before - if original_connected_op in self._added_components_mapping: - # Set source to the already added operations port - new_signal.set_source( - self._added_components_mapping[original_connected_op].output( - original_signal.source.index)) + # Check if connected Operation has been added before. + if original_connected_op in self._original_components_to_new: + # Set source to the already added operations port. + new_signal.set_source(self._original_components_to_new[original_connected_op].output(original_signal.source.index)) else: - # Create new operation, set signal source to it - new_connected_op = self._add_component_copy_unconnected( - original_connected_op) - new_signal.set_source(new_connected_op.output( - original_signal.source.index)) + # Create new operation, set signal source to it. + new_connected_op = self._add_component_unconnected_copy(original_connected_op) + new_signal.set_source(new_connected_op.output(original_signal.source.index)) + self._components_ordered.append(new_connected_op) + self._operations_ordered.append(new_connected_op) - self._components_in_dfs_order.append( - new_connected_op) - - # Add connected operation to queue of operations to visit + # Add connected operation to queue of operations to visit. op_stack.append(original_connected_op) - # Connect output ports + # Connect output ports. for original_output_port in original_op.outputs: - for original_signal in original_output_port.signals: # Check if the signal is one of the SFG's output signals. - if original_signal in self._original_output_signals_indexes: - + if original_signal in self._original_output_signals_to_indices: # New signal already created during first step of constructor. - new_signal = self._added_components_mapping[ - original_signal] - new_signal.set_source( - new_op.output(original_output_port.index)) - - self._components_in_dfs_order.extend( - [new_signal, new_signal.destination.operation]) + new_signal = self._original_components_to_new[original_signal] + new_signal.set_source(new_op.output(original_output_port.index)) + self._components_ordered.extend([new_signal, new_signal.destination.operation]) + self._operations_ordered.append(new_signal.destination.operation) # Check if signal has not been added before. - elif original_signal not in self._added_components_mapping: + elif original_signal not in self._original_components_to_new: if original_signal.source is None: - raise ValueError( - "Dangling signal without source in SFG") - - new_signal = self._add_component_copy_unconnected( - original_signal) - new_signal.set_source( - new_op.output(original_output_port.index)) + raise ValueError("Dangling signal without source in SFG") - self._components_in_dfs_order.append(new_signal) + new_signal = self._add_component_unconnected_copy(original_signal) + new_signal.set_source(new_op.output(original_output_port.index)) + self._components_ordered.append(new_signal) original_connected_op = original_signal.destination.operation # Check if connected operation has been added. - if original_connected_op in self._added_components_mapping: - # Set destination to the already connected operations port - new_signal.set_destination( - self._added_components_mapping[original_connected_op].input( - original_signal.destination.index)) - + if original_connected_op in self._original_components_to_new: + # Set destination to the already connected operations port. + new_signal.set_destination(self._original_components_to_new[original_connected_op].input(original_signal.destination.index)) else: # Create new operation, set destination to it. - new_connected_op = self._add_component_copy_unconnected( - original_connected_op) - new_signal.set_destination(new_connected_op.input( - original_signal.destination.index)) - - self._components_in_dfs_order.append( - new_connected_op) + new_connected_op = self._add_component_unconnected_copy(original_connected_op) + new_signal.set_destination(new_connected_op.input(original_signal.destination.index)) + self._components_ordered.append(new_connected_op) + self._operations_ordered.append(new_connected_op) - # Add connected operation to the queue of operations to visist + # Add connected operation to the queue of operations to visit. op_stack.append(original_connected_op) - def _evaluate_source(self, src: OutputPort) -> Number: - input_values = [] - for input_port in src.operation.inputs: - input_src = input_port.signals[0].source - input_values.append(self._evaluate_source(input_src)) - return src.operation.evaluate_output(src.index, input_values) - - - def __str__(self): - """Prints operations, inputs and outputs in a SFG - """ - - output_string = "" - - for comp in self._components_in_dfs_order: - if isinstance(comp, Operation): - for key, value in self._components_by_id.items(): - if value is comp: - output_string += "id: " + key + ", name: " - - if comp.name != None: - output_string += comp.name + ", " - else: - output_string += "-, " - - if comp.type_name is "c": - output_string += "value: " + str(comp.value) + ", input: [" - else: - output_string += "input: [" - - counter_input = 0 - for input in comp.inputs: - counter_input += 1 - for signal in input.signals: - for key, value in self._components_by_id.items(): - if value is signal: - output_string += key + ", " - - if counter_input > 0: - output_string = output_string[:-2] - output_string += "], output: [" - counter_output = 0 - for output in comp.outputs: - counter_output += 1 - for signal in output.signals: - for key, value in self._components_by_id.items(): - if value is signal: - output_string += key + ", " - if counter_output > 0: - output_string = output_string[:-2] - output_string += "]\n" - - return output_string - - - + def _evaluate_source(self, src: OutputPort, results: MutableResultMap, registers: MutableRegisterMap, prefix: str) -> Number: + src_prefix = prefix + if src_prefix: + src_prefix += "." + src_prefix += src.operation.graph_id + + key = src.operation.key(src.index, src_prefix) + if key in results: + value = results[key] + if value is None: + raise RuntimeError(f"Direct feedback loop detected when evaluating operation.") + return value + + results[key] = src.operation.current_output(src.index, registers, src_prefix) + input_values = [self._evaluate_source(input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs] + value = src.operation.evaluate_output(src.index, input_values, results, registers, src_prefix) + results[key] = value + return value diff --git a/b_asic/simulation.py b/b_asic/simulation.py index a2ce11b3263d517cba79c92093e594d712c5b8f3..9d0d154fa899923ee28cf444512262fb85c73a3a 100644 --- a/b_asic/simulation.py +++ b/b_asic/simulation.py @@ -3,41 +3,111 @@ B-ASIC Simulation Module. TODO: More info. """ +from collections import defaultdict from numbers import Number -from typing import List, Dict +from typing import List, Dict, DefaultDict, Callable, Sequence, Mapping, Union, Optional +from b_asic.operation import ResultKey, ResultMap +from b_asic.signal_flow_graph import SFG -class OperationState: - """Simulation state of an operation. + +InputProvider = Union[Number, Sequence[Number], Callable[[int], Number]] + + +class Simulation: + """Simulation. TODO: More info. """ - output_values: List[Number] - iteration: int + _sfg: SFG + _results: DefaultDict[int, Dict[str, Number]] + _registers: Dict[str, Number] + _iteration: int + _input_functions: Sequence[Callable[[int], Number]] + _current_input_values: Sequence[Number] + _latest_output_values: Sequence[Number] + _save_results: bool - def __init__(self): - self.output_values = [] - self.iteration = 0 + def __init__(self, sfg: SFG, input_providers: Optional[Sequence[Optional[InputProvider]]] = None, save_results: bool = False): + self._sfg = sfg + self._results = defaultdict(dict) + self._registers = {} + self._iteration = 0 + self._input_functions = [lambda _: 0 for _ in range(self._sfg.input_count)] + self._current_input_values = [0 for _ in range(self._sfg.input_count)] + self._latest_output_values = [0 for _ in range(self._sfg.output_count)] + self._save_results = save_results + if input_providers is not None: + self.set_inputs(input_providers) + def set_input(self, index: int, input_provider: InputProvider) -> None: + """Set the input function used to get values for the specific input at the given index to the internal SFG.""" + if index < 0 or index >= len(self._input_functions): + raise IndexError(f"Input index out of range (expected 0-{len(self._input_functions) - 1}, got {index})") + if callable(input_provider): + self._input_functions[index] = input_provider + elif isinstance(input_provider, Number): + self._input_functions[index] = lambda _: input_provider + else: + self._input_functions[index] = lambda n: input_provider[n] -class SimulationState: - """Simulation state. - TODO: More info. - """ + def set_inputs(self, input_providers: Sequence[Optional[InputProvider]]) -> None: + """Set the input functions used to get values for the inputs to the internal SFG.""" + if len(input_providers) != self._sfg.input_count: + raise ValueError(f"Wrong number of inputs supplied to simulation (expected {self._sfg.input_count}, got {len(input_providers)})") + self._input_functions = [None for _ in range(self._sfg.input_count)] + for index, input_provider in enumerate(input_providers): + if input_provider is not None: + self.set_input(index, input_provider) + + @property + def save_results(self) -> bool: + """Get the flag that determines if the results of .""" + return self._save_results + + @save_results.setter + def save_results(self, save_results) -> None: + self._save_results = save_results + + def run(self) -> Sequence[Number]: + """Run one iteration of the simulation and return the resulting output values.""" + return self.run_for(1) + + def run_until(self, iteration: int) -> Sequence[Number]: + """Run the simulation until its iteration is greater than or equal to the given iteration + and return the resulting output values. + """ + while self._iteration < iteration: + self._current_input_values = [self._input_functions[i](self._iteration) for i in range(self._sfg.input_count)] + self._latest_output_values = self._sfg.evaluate_outputs(self._current_input_values, self._results[self._iteration], self._registers) + if not self._save_results: + del self._results[self.iteration] + self._iteration += 1 + return self._latest_output_values + + def run_for(self, iterations: int) -> Sequence[Number]: + """Run a given number of iterations of the simulation and return the resulting output values.""" + return self.run_until(self._iteration + iterations) + + @property + def iteration(self) -> int: + """Get the current iteration number of the simulation.""" + return self._iteration + + @property + def results(self) -> Mapping[int, ResultMap]: + """Get a mapping of all results, including intermediate values, calculated for each iteration up until now. + The outer mapping maps from iteration number to value mapping. The value mapping maps output port identifiers to values. + Example: {0: {"c1": 3, "c2": 4, "bfly1.0": 7, "bfly1.1": -1, "0": 7}} + """ + return self._results + + def clear_results(self) -> None: + """Clear all results that were saved until now.""" + self._results.clear() - operation_states: Dict[int, OperationState] - iteration: int - - def __init__(self): - op_state = OperationState() - self.operation_states = {1: op_state} - self.iteration = 0 - - # @property - # #def iteration(self): - # return self.iteration - # @iteration.setter - # def iteration(self, new_iteration: int): - # self.iteration = new_iteration - # - # TODO: More stuff + def clear_state(self) -> None: + """Clear all current state of the simulation, except for the results and iteration.""" + self._registers.clear() + self._current_input_values = [0 for _ in range(self._sfg.input_count)] + self._latest_output_values = [0 for _ in range(self._sfg.output_count)] \ No newline at end of file diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py index 465c0086d0120b10e27f769a216874b2e08dd53c..96d341b9cac01cc2d260544b4d2501d68c0808c0 100644 --- a/b_asic/special_operations.py +++ b/b_asic/special_operations.py @@ -4,9 +4,9 @@ TODO: More info. """ from numbers import Number -from typing import Optional +from typing import Optional, Sequence -from b_asic.operation import AbstractOperation +from b_asic.operation import AbstractOperation, ResultKey, RegisterMap, MutableResultMap, MutableRegisterMap from b_asic.graph_component import Name, TypeName from b_asic.port import SignalSourceProvider @@ -29,12 +29,12 @@ class Input(AbstractOperation): @property def value(self) -> Number: - """TODO: docstring""" + """Get the current value of this input.""" return self.param("value") @value.setter - def value(self, value: Number): - """TODO: docstring""" + def value(self, value: Number) -> None: + """Set the current value of this input.""" self.set_param("value", value) @@ -44,11 +44,48 @@ class Output(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 0, name = name, input_sources=[src0]) + super().__init__(input_count = 1, output_count = 0, name = name, input_sources = [src0]) @property def type_name(self) -> TypeName: return "out" - def evaluate(self): - return None \ No newline at end of file + def evaluate(self, _): + return None + + +class Register(AbstractOperation): + """Unit delay operation. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, initial_value: Number = 0, name: Name = ""): + super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + self.set_param("initial_value", initial_value) + + @property + def type_name(self) -> TypeName: + return "reg" + + def evaluate(self, a): + return self.param("initial_value") + + def current_output(self, index: int, registers: Optional[RegisterMap] = None, prefix: str = "") -> Optional[Number]: + if registers is not None: + return registers.get(self.key(index, prefix), self.param("initial_value")) + return self.param("initial_value") + + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: + if index != 0: + raise IndexError(f"Output index out of range (expected 0-0, got {index})") + if len(input_values) != 1: + raise ValueError(f"Wrong number of inputs supplied to SFG for evaluation (expected 1, got {len(input_values)})") + + key = self.key(index, prefix) + value = self.param("initial_value") + if registers is not None: + value = registers.get(key, value) + registers[key] = self.truncate_inputs(input_values)[0] + if results is not None: + results[key] = value + return value \ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py index 64f39843c53a4369781a269fd7fc30ad9aa1d255..48b49489424817e1439f6b2b6eb3d7cd63b29a75 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,4 +1,5 @@ from test.fixtures.signal import signal, signals from test.fixtures.operation_tree import * from test.fixtures.port import * +from test.fixtures.signal_flow_graph import * import pytest diff --git a/test/fixtures/operation_tree.py b/test/fixtures/operation_tree.py index 94a1e42f724fdf7f14dbd13debaccc850fbbf552..fc8008fa4098ca488e23766f5ff7d05711300685 100644 --- a/test/fixtures/operation_tree.py +++ b/test/fixtures/operation_tree.py @@ -1,30 +1,60 @@ -from b_asic.core_operations import Addition, Constant -from b_asic.signal import Signal - import pytest +from b_asic import Addition, Constant, Signal + + @pytest.fixture def operation(): return Constant(2) @pytest.fixture def operation_tree(): - """Return a addition operation connected with 2 constants. - ---C---+ - +--A - ---C---+ + """Valid addition operation connected with 2 constants. + 2---+ + | + v + add = 2 + 3 = 5 + ^ + | + 3---+ """ return Addition(Constant(2), Constant(3)) @pytest.fixture def large_operation_tree(): - """Return an addition operation connected with a large operation tree with 2 other additions and 4 constants. - ---C---+ - +--A---+ - ---C---+ | - +---A - ---C---+ | - +--A---+ - ---C---+ + """Valid addition operation connected with a large operation tree with 2 other additions and 4 constants. + 2---+ + | + v + add---+ + ^ | + | | + 3---+ v + add = (2 + 3) + (4 + 5) = 14 + 4---+ ^ + | | + v | + add---+ + ^ + | + 5---+ """ return Addition(Addition(Constant(2), Constant(3)), Addition(Constant(4), Constant(5))) + +@pytest.fixture +def operation_graph_with_cycle(): + """Invalid addition operation connected with an operation graph containing a cycle. + +-+ + | | + v | + add+---+ + ^ | + | v + 7 add = (? + 7) + 6 = ? + ^ + | + 6 + """ + add1 = Addition(None, Constant(7)) + add1.input(0).connect(add1) + return Addition(add1, Constant(6)) diff --git a/test/fixtures/port.py b/test/fixtures/port.py index 63632ecdb3a9d81a7f27759cd7166af3163c9e94..fa528b8d9437e60b99c1ec426f317eb97b0164f2 100644 --- a/test/fixtures/port.py +++ b/test/fixtures/port.py @@ -1,5 +1,7 @@ import pytest -from b_asic.port import InputPort, OutputPort + +from b_asic import InputPort, OutputPort + @pytest.fixture def input_port(): diff --git a/test/fixtures/signal.py b/test/fixtures/signal.py index 0c5692feb3203f37876e48df0ab7f2caa69c4d45..4dba99e24bce16aba67cba58057b3cde76f0923d 100644 --- a/test/fixtures/signal.py +++ b/test/fixtures/signal.py @@ -1,6 +1,8 @@ import pytest + from b_asic import Signal + @pytest.fixture def signal(): """Return a signal with no connections.""" diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..0a6c554d1340478dad25a11655f0542bf6fba1d1 --- /dev/null +++ b/test/fixtures/signal_flow_graph.py @@ -0,0 +1,71 @@ +import pytest + +from b_asic import SFG, Input, Output, Constant, Register + + +@pytest.fixture +def sfg_two_inputs_two_outputs(): + """Valid SFG with two inputs and two outputs. + . . + in1-------+ +--------->out1 + . | | . + . v | . + . add1+--+ . + . ^ | . + . | v . + in2+------+ add2---->out2 + | . ^ . + | . | . + +------------+ . + . . + out1 = in1 + in2 + out2 = in1 + 2 * in2 + """ + in1 = Input() + in2 = Input() + add1 = in1 + in2 + add2 = add1 + in2 + out1 = Output(add1) + out2 = Output(add2) + return SFG(inputs = [in1, in2], outputs = [out1, out2]) + +@pytest.fixture +def sfg_nested(): + """Valid SFG with two inputs and one output. + out1 = in1 + (in1 + in1 * in2) * (in1 + in2 * (in1 + in1 * in2)) + """ + mac_in1 = Input() + mac_in2 = Input() + mac_in3 = Input() + mac_out1 = Output(mac_in1 + mac_in2 * mac_in3) + MAC = SFG(inputs = [mac_in1, mac_in2, mac_in3], outputs = [mac_out1]) + + in1 = Input() + in2 = Input() + mac1 = MAC(in1, in1, in2) + mac2 = MAC(in1, in2, mac1) + mac3 = MAC(in1, mac1, mac2) + out1 = Output(mac3) + return SFG(inputs = [in1, in2], outputs = [out1]) + +@pytest.fixture +def sfg_delay(): + """Valid SFG with one input and one output. + out1 = in1' + """ + in1 = Input() + reg1 = Register(in1) + out1 = Output(reg1) + return SFG(inputs = [in1], outputs = [out1]) + +@pytest.fixture +def sfg_accumulator(): + """Valid SFG with two inputs and one output. + data_out = (data_in' + data_in) * (1 - reset) + """ + data_in = Input() + reset = Input() + reg = Register() + reg.input(0).connect((reg + data_in) * (1 - reset)) + data_out = Output(reg) + return SFG(inputs = [data_in, reset], outputs = [data_out]) \ No newline at end of file diff --git a/test/test_abstract_operation.py b/test/test_abstract_operation.py index 626a2dc3e5e26fb76d9266dcdd31940681df5c6e..5423ecdf08c420df5dccc6393c3ad6637961172b 100644 --- a/test/test_abstract_operation.py +++ b/test/test_abstract_operation.py @@ -2,11 +2,10 @@ B-ASIC test suite for the AbstractOperation class. """ -from b_asic.core_operations import Addition, ConstantAddition, Subtraction, ConstantSubtraction, \ - Multiplication, ConstantMultiplication, Division, ConstantDivision - import pytest +from b_asic import Addition, Subtraction, Multiplication, ConstantMultiplication, Division + def test_addition_overload(): """Tests addition overloading for both operation and number argument.""" @@ -14,15 +13,19 @@ def test_addition_overload(): add2 = Addition(None, None, "add2") add3 = add1 + add2 - assert isinstance(add3, Addition) assert add3.input(0).signals == add1.output(0).signals assert add3.input(1).signals == add2.output(0).signals add4 = add3 + 5 - - assert isinstance(add4, ConstantAddition) + assert isinstance(add4, Addition) assert add4.input(0).signals == add3.output(0).signals + assert add4.input(1).signals[0].source.operation.value == 5 + + add5 = 5 + add4 + assert isinstance(add5, Addition) + assert add5.input(0).signals[0].source.operation.value == 5 + assert add5.input(1).signals == add4.output(0).signals def test_subtraction_overload(): @@ -31,15 +34,19 @@ def test_subtraction_overload(): add2 = Addition(None, None, "add2") sub1 = add1 - add2 - assert isinstance(sub1, Subtraction) assert sub1.input(0).signals == add1.output(0).signals assert sub1.input(1).signals == add2.output(0).signals sub2 = sub1 - 5 - - assert isinstance(sub2, ConstantSubtraction) + assert isinstance(sub2, Subtraction) assert sub2.input(0).signals == sub1.output(0).signals + assert sub2.input(1).signals[0].source.operation.value == 5 + + sub3 = 5 - sub2 + assert isinstance(sub3, Subtraction) + assert sub3.input(0).signals[0].source.operation.value == 5 + assert sub3.input(1).signals == sub2.output(0).signals def test_multiplication_overload(): @@ -48,15 +55,19 @@ def test_multiplication_overload(): add2 = Addition(None, None, "add2") mul1 = add1 * add2 - assert isinstance(mul1, Multiplication) assert mul1.input(0).signals == add1.output(0).signals assert mul1.input(1).signals == add2.output(0).signals mul2 = mul1 * 5 - assert isinstance(mul2, ConstantMultiplication) assert mul2.input(0).signals == mul1.output(0).signals + assert mul2.value == 5 + + mul3 = 5 * mul2 + assert isinstance(mul3, ConstantMultiplication) + assert mul3.input(0).signals == mul2.output(0).signals + assert mul3.value == 5 def test_division_overload(): @@ -65,13 +76,17 @@ def test_division_overload(): add2 = Addition(None, None, "add2") div1 = add1 / add2 - assert isinstance(div1, Division) assert div1.input(0).signals == add1.output(0).signals assert div1.input(1).signals == add2.output(0).signals div2 = div1 / 5 - - assert isinstance(div2, ConstantDivision) + assert isinstance(div2, Division) assert div2.input(0).signals == div1.output(0).signals + assert div2.input(1).signals[0].source.operation.value == 5 + + div3 = 5 / div2 + assert isinstance(div3, Division) + assert div3.input(0).signals[0].source.operation.value == 5 + assert div3.input(1).signals == div2.output(0).signals diff --git a/test/test_core_operations.py b/test/test_core_operations.py index 854ccf85f447e430af303dc9a45c8946ac8d7828..4d0039b558e81c5cd74f151f93f0bc0194a702d5 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -2,313 +2,165 @@ B-ASIC test suite for the core operations. """ -from b_asic.core_operations import Constant, Addition, Subtraction, \ - Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, \ - Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, \ - ConstantDivision, Butterfly +from b_asic import \ + Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \ + SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly -# Constant tests. +class TestConstant: + def test_constant_positive(self): + test_operation = Constant(3) + assert test_operation.evaluate_output(0, []) == 3 -def test_constant(): - constant_operation = Constant(3) - assert constant_operation.evaluate() == 3 + def test_constant_negative(self): + test_operation = Constant(-3) + assert test_operation.evaluate_output(0, []) == -3 + def test_constant_complex(self): + test_operation = Constant(3+4j) + assert test_operation.evaluate_output(0, []) == 3+4j -def test_constant_negative(): - constant_operation = Constant(-3) - assert constant_operation.evaluate() == -3 +class TestAddition: + def test_addition_positive(self): + test_operation = Addition() + assert test_operation.evaluate_output(0, [3, 5]) == 8 -def test_constant_complex(): - constant_operation = Constant(3+4j) - assert constant_operation.evaluate() == 3+4j + def test_addition_negative(self): + test_operation = Addition() + assert test_operation.evaluate_output(0, [-3, -5]) == -8 -# Addition tests. + def test_addition_complex(self): + test_operation = Addition() + assert test_operation.evaluate_output(0, [3+5j, 4+6j]) == 7+11j -def test_addition(): - test_operation = Addition() - constant_operation = Constant(3) - constant_operation_2 = Constant(5) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == 8 +class TestSubtraction: + def test_subtraction_positive(self): + test_operation = Subtraction() + assert test_operation.evaluate_output(0, [5, 3]) == 2 + def test_subtraction_negative(self): + test_operation = Subtraction() + assert test_operation.evaluate_output(0, [-5, -3]) == -2 -def test_addition_negative(): - test_operation = Addition() - constant_operation = Constant(-3) - constant_operation_2 = Constant(-5) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == -8 + def test_subtraction_complex(self): + test_operation = Subtraction() + assert test_operation.evaluate_output(0, [3+5j, 4+6j]) == -1-1j -def test_addition_complex(): - test_operation = Addition() - constant_operation = Constant((3+5j)) - constant_operation_2 = Constant((4+6j)) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j) +class TestMultiplication: + def test_multiplication_positive(self): + test_operation = Multiplication() + assert test_operation.evaluate_output(0, [5, 3]) == 15 -# Subtraction tests. + def test_multiplication_negative(self): + test_operation = Multiplication() + assert test_operation.evaluate_output(0, [-5, -3]) == 15 + def test_multiplication_complex(self): + test_operation = Multiplication() + assert test_operation.evaluate_output(0, [3+5j, 4+6j]) == -18+38j -def test_subtraction(): - test_operation = Subtraction() - constant_operation = Constant(5) - constant_operation_2 = Constant(3) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == 2 +class TestDivision: + def test_division_positive(self): + test_operation = Division() + assert test_operation.evaluate_output(0, [30, 5]) == 6 -def test_subtraction_negative(): - test_operation = Subtraction() - constant_operation = Constant(-5) - constant_operation_2 = Constant(-3) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == -2 + def test_division_negative(self): + test_operation = Division() + assert test_operation.evaluate_output(0, [-30, -5]) == 6 + def test_division_complex(self): + test_operation = Division() + assert test_operation.evaluate_output(0, [60+40j, 10+20j]) == 2.8-1.6j -def test_subtraction_complex(): - test_operation = Subtraction() - constant_operation = Constant((3+5j)) - constant_operation_2 = Constant((4+6j)) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j) -# Multiplication tests. +class TestSquareRoot: + def test_squareroot_positive(self): + test_operation = SquareRoot() + assert test_operation.evaluate_output(0, [36]) == 6 + def test_squareroot_negative(self): + test_operation = SquareRoot() + assert test_operation.evaluate_output(0, [-36]) == 6j -def test_multiplication(): - test_operation = Multiplication() - constant_operation = Constant(5) - constant_operation_2 = Constant(3) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 + def test_squareroot_complex(self): + test_operation = SquareRoot() + assert test_operation.evaluate_output(0, [48+64j]) == 8+4j -def test_multiplication_negative(): - test_operation = Multiplication() - constant_operation = Constant(-5) - constant_operation_2 = Constant(-3) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 +class TestComplexConjugate: + def test_complexconjugate_positive(self): + test_operation = ComplexConjugate() + assert test_operation.evaluate_output(0, [3+4j]) == 3-4j + def test_test_complexconjugate_negative(self): + test_operation = ComplexConjugate() + assert test_operation.evaluate_output(0, [-3-4j]) == -3+4j -def test_multiplication_complex(): - test_operation = Multiplication() - constant_operation = Constant((3+5j)) - constant_operation_2 = Constant((4+6j)) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j) -# Division tests. +class TestMax: + def test_max_positive(self): + test_operation = Max() + assert test_operation.evaluate_output(0, [30, 5]) == 30 + def test_max_negative(self): + test_operation = Max() + assert test_operation.evaluate_output(0, [-30, -5]) == -5 -def test_division(): - test_operation = Division() - constant_operation = Constant(30) - constant_operation_2 = Constant(5) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 +class TestMin: + def test_min_positive(self): + test_operation = Min() + assert test_operation.evaluate_output(0, [30, 5]) == 5 -def test_division_negative(): - test_operation = Division() - constant_operation = Constant(-30) - constant_operation_2 = Constant(-5) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 + def test_min_negative(self): + test_operation = Min() + assert test_operation.evaluate_output(0, [-30, -5]) == -30 -def test_division_complex(): - test_operation = Division() - constant_operation = Constant((60+40j)) - constant_operation_2 = Constant((10+20j)) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j) +class TestAbsolute: + def test_absolute_positive(self): + test_operation = Absolute() + assert test_operation.evaluate_output(0, [30]) == 30 -# SquareRoot tests. + def test_absolute_negative(self): + test_operation = Absolute() + assert test_operation.evaluate_output(0, [-5]) == 5 + def test_absolute_complex(self): + test_operation = Absolute() + assert test_operation.evaluate_output(0, [3+4j]) == 5.0 -def test_squareroot(): - test_operation = SquareRoot() - constant_operation = Constant(36) - assert test_operation.evaluate(constant_operation.evaluate()) == 6 +class TestConstantMultiplication: + def test_constantmultiplication_positive(self): + test_operation = ConstantMultiplication(5) + assert test_operation.evaluate_output(0, [20]) == 100 -def test_squareroot_negative(): - test_operation = SquareRoot() - constant_operation = Constant(-36) - assert test_operation.evaluate(constant_operation.evaluate()) == 6j + def test_constantmultiplication_negative(self): + test_operation = ConstantMultiplication(5) + assert test_operation.evaluate_output(0, [-5]) == -25 + def test_constantmultiplication_complex(self): + test_operation = ConstantMultiplication(3+2j) + assert test_operation.evaluate_output(0, [3+4j]) == 1+18j -def test_squareroot_complex(): - test_operation = SquareRoot() - constant_operation = Constant((48+64j)) - assert test_operation.evaluate(constant_operation.evaluate()) == (8+4j) -# ComplexConjugate tests. +class TestButterfly: + def test_butterfly_positive(self): + test_operation = Butterfly() + assert test_operation.evaluate_output(0, [2, 3]) == 5 + assert test_operation.evaluate_output(1, [2, 3]) == -1 + def test_butterfly_negative(self): + test_operation = Butterfly() + assert test_operation.evaluate_output(0, [-2, -3]) == -5 + assert test_operation.evaluate_output(1, [-2, -3]) == 1 -def test_complexconjugate(): - test_operation = ComplexConjugate() - constant_operation = Constant(3+4j) - assert test_operation.evaluate(constant_operation.evaluate()) == (3-4j) - - -def test_test_complexconjugate_negative(): - test_operation = ComplexConjugate() - constant_operation = Constant(-3-4j) - assert test_operation.evaluate(constant_operation.evaluate()) == (-3+4j) - -# Max tests. - - -def test_max(): - test_operation = Max() - constant_operation = Constant(30) - constant_operation_2 = Constant(5) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == 30 - - -def test_max_negative(): - test_operation = Max() - constant_operation = Constant(-30) - constant_operation_2 = Constant(-5) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == -5 - -# Min tests. - - -def test_min(): - test_operation = Min() - constant_operation = Constant(30) - constant_operation_2 = Constant(5) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == 5 - - -def test_min_negative(): - test_operation = Min() - constant_operation = Constant(-30) - constant_operation_2 = Constant(-5) - assert test_operation.evaluate( - constant_operation.evaluate(), constant_operation_2.evaluate()) == -30 - -# Absolute tests. - - -def test_absolute(): - test_operation = Absolute() - constant_operation = Constant(30) - assert test_operation.evaluate(constant_operation.evaluate()) == 30 - - -def test_absolute_negative(): - test_operation = Absolute() - constant_operation = Constant(-5) - assert test_operation.evaluate(constant_operation.evaluate()) == 5 - - -def test_absolute_complex(): - test_operation = Absolute() - constant_operation = Constant((3+4j)) - assert test_operation.evaluate(constant_operation.evaluate()) == 5.0 - -# ConstantMultiplication tests. - - -def test_constantmultiplication(): - test_operation = ConstantMultiplication(5) - constant_operation = Constant(20) - assert test_operation.evaluate(constant_operation.evaluate()) == 100 - - -def test_constantmultiplication_negative(): - test_operation = ConstantMultiplication(5) - constant_operation = Constant(-5) - assert test_operation.evaluate(constant_operation.evaluate()) == -25 - - -def test_constantmultiplication_complex(): - test_operation = ConstantMultiplication(3+2j) - constant_operation = Constant((3+4j)) - assert test_operation.evaluate(constant_operation.evaluate()) == (1+18j) - -# ConstantAddition tests. - - -def test_constantaddition(): - test_operation = ConstantAddition(5) - constant_operation = Constant(20) - assert test_operation.evaluate(constant_operation.evaluate()) == 25 - - -def test_constantaddition_negative(): - test_operation = ConstantAddition(4) - constant_operation = Constant(-5) - assert test_operation.evaluate(constant_operation.evaluate()) == -1 - - -def test_constantaddition_complex(): - test_operation = ConstantAddition(3+2j) - constant_operation = Constant((3+4j)) - assert test_operation.evaluate(constant_operation.evaluate()) == (6+6j) - -# ConstantSubtraction tests. - - -def test_constantsubtraction(): - test_operation = ConstantSubtraction(5) - constant_operation = Constant(20) - assert test_operation.evaluate(constant_operation.evaluate()) == 15 - - -def test_constantsubtraction_negative(): - test_operation = ConstantSubtraction(4) - constant_operation = Constant(-5) - assert test_operation.evaluate(constant_operation.evaluate()) == -9 - - -def test_constantsubtraction_complex(): - test_operation = ConstantSubtraction(4+6j) - constant_operation = Constant((3+4j)) - assert test_operation.evaluate(constant_operation.evaluate()) == (-1-2j) - -# ConstantDivision tests. - - -def test_constantdivision(): - test_operation = ConstantDivision(5) - constant_operation = Constant(20) - assert test_operation.evaluate(constant_operation.evaluate()) == 4 - - -def test_constantdivision_negative(): - test_operation = ConstantDivision(4) - constant_operation = Constant(-20) - assert test_operation.evaluate(constant_operation.evaluate()) == -5 - - -def test_constantdivision_complex(): - test_operation = ConstantDivision(2+2j) - constant_operation = Constant((10+10j)) - assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j) - - -def test_butterfly(): - test_operation = Butterfly() - assert list(test_operation.evaluate(2, 3)) == [5, -1] - - -def test_butterfly_negative(): - test_operation = Butterfly() - assert list(test_operation.evaluate(-2, -3)) == [-5, 1] - - -def test_buttefly_complex(): - test_operation = Butterfly() - assert list(test_operation.evaluate(2+1j, 3-2j)) == [5-1j, -1+3j] + def test_buttefly_complex(self): + test_operation = Butterfly() + assert test_operation.evaluate_output(0, [2+1j, 3-2j]) == 5-1j + assert test_operation.evaluate_output(1, [2+1j, 3-2j]) == -1+3j diff --git a/test/test_graph_id_generator.py b/test/test_graph_id_generator.py index b8e0cdebb7f1cc32297bacff89314244dda7cd6f..72c923b63b6af74296cca86cd432da7e488d55b6 100644 --- a/test/test_graph_id_generator.py +++ b/test/test_graph_id_generator.py @@ -2,9 +2,10 @@ B-ASIC test suite for graph id generator. """ -from b_asic.signal_flow_graph import GraphIDGenerator, GraphID import pytest +from b_asic import GraphIDGenerator, GraphID + @pytest.fixture def graph_id_generator(): return GraphIDGenerator() diff --git a/test/test_inputport.py b/test/test_inputport.py index b43bf8e3d11eb3286c087c6a8bbb0b46956e51fb..85f892217c7e0f766417f6cc2e6d066d48d8a537 100644 --- a/test/test_inputport.py +++ b/test/test_inputport.py @@ -4,8 +4,7 @@ B-ASIC test suite for Inputport import pytest -from b_asic import InputPort, OutputPort -from b_asic import Signal +from b_asic import InputPort, OutputPort, Signal @pytest.fixture def inp_port(): @@ -74,28 +73,3 @@ def test_add_signal_then_disconnect(inp_port, s_w_source): assert inp_port.signals == [] assert s_w_source.source.signals == [s_w_source] assert s_w_source.destination is None - -def test_set_value_length_pos_int(inp_port): - inp_port.value_length = 10 - assert inp_port.value_length == 10 - -def test_set_value_length_zero(inp_port): - inp_port.value_length = 0 - assert inp_port.value_length == 0 - -def test_set_value_length_neg_int(inp_port): - with pytest.raises(Exception): - inp_port.value_length = -10 - -def test_set_value_length_complex(inp_port): - with pytest.raises(Exception): - inp_port.value_length = (2+4j) - -def test_set_value_length_float(inp_port): - with pytest.raises(Exception): - inp_port.value_length = 3.2 - -def test_set_value_length_pos_then_none(inp_port): - inp_port.value_length = 10 - inp_port.value_length = None - assert inp_port.value_length is None diff --git a/test/test_operation.py b/test/test_operation.py index c3a05bb5a08fa443753c2bafcf2b035274098455..b76ba16d11425c0ce868e4fa0b4c88d9f862e23f 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -1,9 +1,6 @@ -from b_asic.core_operations import Constant, Addition, ConstantAddition, Butterfly -from b_asic.signal import Signal -from b_asic.port import InputPort, OutputPort - import pytest +from b_asic import Constant, Addition class TestTraverse: def test_traverse_single_tree(self, operation): @@ -13,19 +10,16 @@ class TestTraverse: def test_traverse_tree(self, operation_tree): """Traverse a basic addition tree with two constants.""" - assert len(list(operation_tree.traverse())) == 3 + assert len(list(operation_tree.traverse())) == 5 def test_traverse_large_tree(self, large_operation_tree): """Traverse a larger tree.""" - assert len(list(large_operation_tree.traverse())) == 7 + assert len(list(large_operation_tree.traverse())) == 13 def test_traverse_type(self, large_operation_tree): - traverse = list(large_operation_tree.traverse()) - assert len( - list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3 - assert len( - list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4 + result = list(large_operation_tree.traverse()) + assert len(list(filter(lambda type_: isinstance(type_, Addition), result))) == 3 + assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4 - def test_traverse_loop(self, operation_tree): - # TODO: Construct a graph that contains a loop and make sure you can traverse it properly. - assert True + def test_traverse_loop(self, operation_graph_with_cycle): + assert len(list(operation_graph_with_cycle.traverse())) == 8 \ No newline at end of file diff --git a/test/test_outputport.py b/test/test_outputport.py index 21f08764ac4d7f9497dc02615cce343120598959..189c89225f88f263294d24aae21995ee7a821ada 100644 --- a/test/test_outputport.py +++ b/test/test_outputport.py @@ -1,9 +1,11 @@ """ B-ASIC test suite for OutputPort. """ -from b_asic import OutputPort, InputPort, Signal import pytest +from b_asic import OutputPort, InputPort, Signal + + @pytest.fixture def output_port(): return OutputPort(None, 0) @@ -16,6 +18,7 @@ def input_port(): def list_of_input_ports(): return [InputPort(None, i) for i in range(0, 3)] + class TestConnect: def test_multiple_ports(self, output_port, list_of_input_ports): """Can multiple ports connect to an output port?""" diff --git a/test/test_print_sfg.py b/test/test_print_sfg.py index feb3626e4791bd2d67f4711abd9a108c3cc0aec8..49b0950d82857f86ba652e76075b5d3cb40e1584 100644 --- a/test/test_print_sfg.py +++ b/test/test_print_sfg.py @@ -4,7 +4,7 @@ B-ASIC test suite for printing a SFG from b_asic.signal_flow_graph import SFG -from b_asic.core_operations import Addition, Multiplication, Constant, ConstantAddition +from b_asic.core_operations import Addition, Multiplication, Constant from b_asic.port import InputPort, OutputPort from b_asic.signal import Signal from b_asic.special_operations import Input, Output diff --git a/test/test_sfg.py b/test/test_sfg.py deleted file mode 100644 index af9dfe179751fd620d5880494215c3b1cfb8571b..0000000000000000000000000000000000000000 --- a/test/test_sfg.py +++ /dev/null @@ -1,116 +0,0 @@ -from b_asic import SFG -from b_asic.signal import Signal -from b_asic.core_operations import Addition, Constant, Multiplication -from b_asic.special_operations import Input, Output - - -class TestConstructor: - def test_direct_input_to_output_sfg_construction(self): - inp = Input("INP1") - out = Output(None, "OUT1") - out.input(0).connect(inp, "S1") - - sfg = SFG(inputs=[inp], outputs=[out]) - - assert len(list(sfg.components)) == 3 - assert sfg.input_count == 1 - assert sfg.output_count == 1 - - def test_same_signal_input_and_output_sfg_construction(self): - add1 = Addition(None, None, "ADD1") - add2 = Addition(None, None, "ADD2") - - sig1 = add2.input(0).connect(add1, "S1") - - sfg = SFG(input_signals=[sig1], output_signals=[sig1]) - - assert len(list(sfg.components)) == 3 - assert sfg.input_count == 1 - assert sfg.output_count == 1 - - def test_outputs_construction(self, operation_tree): - outp = Output(operation_tree) - sfg = SFG(outputs=[outp]) - - assert len(list(sfg.components)) == 7 - assert sfg.input_count == 0 - assert sfg.output_count == 1 - - def test_signals_construction(self, operation_tree): - outs = Signal(source=operation_tree.output(0)) - sfg = SFG(output_signals=[outs]) - - assert len(list(sfg.components)) == 7 - assert sfg.input_count == 0 - assert sfg.output_count == 1 - - -class TestDeepCopy: - def test_deep_copy_no_duplicates(self): - inp1 = Input("INP1") - inp2 = Input("INP2") - inp3 = Input("INP3") - add1 = Addition(inp1, inp2, "ADD1") - mul1 = Multiplication(add1, inp3, "MUL1") - out1 = Output(mul1, "OUT1") - - mac_sfg = SFG(inputs=[inp1, inp2], - outputs=[out1], name="mac_sfg") - - mac_sfg_deep_copy = mac_sfg.deep_copy() - - for g_id, component in mac_sfg._components_by_id.items(): - component_copy = mac_sfg_deep_copy.find_by_id(g_id) - assert component.name == component_copy.name - - def test_deep_copy(self): - inp1 = Input("INP1") - inp2 = Input("INP2") - inp3 = Input("INP3") - add1 = Addition(None, None, "ADD1") - add2 = Addition(None, None, "ADD2") - mul1 = Multiplication(None, None, "MUL1") - out1 = Output(None, "OUT1") - - add1.input(0).connect(inp1, "S1") - add1.input(1).connect(inp2, "S2") - add2.input(0).connect(add1, "S4") - add2.input(1).connect(inp3, "S3") - mul1.input(0).connect(add1, "S5") - mul1.input(1).connect(add2, "S6") - out1.input(0).connect(mul1, "S7") - - mac_sfg = SFG(inputs=[inp1, inp2], - outputs=[out1], name="mac_sfg") - - mac_sfg_deep_copy = mac_sfg.deep_copy() - - for g_id, component in mac_sfg._components_by_id.items(): - component_copy = mac_sfg_deep_copy.find_by_id(g_id) - assert component.name == component_copy.name - - -class TestComponents: - - def test_advanced_components(self): - inp1 = Input("INP1") - inp2 = Input("INP2") - inp3 = Input("INP3") - add1 = Addition(None, None, "ADD1") - add2 = Addition(None, None, "ADD2") - mul1 = Multiplication(None, None, "MUL1") - out1 = Output(None, "OUT1") - - add1.input(0).connect(inp1, "S1") - add1.input(1).connect(inp2, "S2") - add2.input(0).connect(add1, "S4") - add2.input(1).connect(inp3, "S3") - mul1.input(0).connect(add1, "S5") - mul1.input(1).connect(add2, "S6") - out1.input(0).connect(mul1, "S7") - - mac_sfg = SFG(inputs=[inp1, inp2], - outputs=[out1], name="mac_sfg") - - assert set([comp.name for comp in mac_sfg.components]) == { - "INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"} diff --git a/test/test_signal.py b/test/test_signal.py index 9a45086a99e55089c9e25100cdd56399ca46a5cc..cad16c9ba5b73b3d597c4e80aa666677d1909888 100644 --- a/test/test_signal.py +++ b/test/test_signal.py @@ -2,11 +2,11 @@ B-ASIC test suit for the signal module which consists of the Signal class. """ -from b_asic.port import InputPort, OutputPort -from b_asic.signal import Signal - import pytest +from b_asic import InputPort, OutputPort, Signal + + def test_signal_creation_and_disconnction_and_connection_changing(): in_port = InputPort(None, 0) out_port = OutputPort(None, 1) @@ -60,3 +60,28 @@ def test_signal_creation_and_disconnction_and_connection_changing(): assert in_port.signals == [s] assert s.source is out_port assert s.destination is in_port + +def test_signal_set_bits_pos_int(signal): + signal.bits = 10 + assert signal.bits == 10 + +def test_signal_set_bits_zero(signal): + signal.bits = 0 + assert signal.bits == 0 + +def test_signal_set_bits_neg_int(signal): + with pytest.raises(Exception): + signal.bits = -10 + +def test_signal_set_bits_complex(signal): + with pytest.raises(Exception): + signal.bits = (2+4j) + +def test_signal_set_bits_float(signal): + with pytest.raises(Exception): + signal.bits = 3.2 + +def test_signal_set_bits_pos_then_none(signal): + signal.bits = 10 + signal.bits = None + assert signal.bits is None \ No newline at end of file diff --git a/test/test_signal_flow_graph.py b/test/test_signal_flow_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..51267cc44ece60c05e622c02c89b8ec1a5d5b17d --- /dev/null +++ b/test/test_signal_flow_graph.py @@ -0,0 +1,149 @@ +import pytest + +from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication + + +class TestConstructor: + def test_direct_input_to_output_sfg_construction(self): + in1 = Input("IN1") + out1 = Output(None, "OUT1") + out1.input(0).connect(in1, "S1") + + sfg = SFG(inputs = [in1], outputs = [out1]) # in1 ---s1---> out1 + + assert len(list(sfg.components)) == 3 + assert len(list(sfg.operations)) == 2 + assert sfg.input_count == 1 + assert sfg.output_count == 1 + + def test_same_signal_input_and_output_sfg_construction(self): + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + + s1 = add2.input(0).connect(add1, "S1") + + sfg = SFG(input_signals = [s1], output_signals = [s1]) # in1 ---s1---> out1 + + assert len(list(sfg.components)) == 3 + assert len(list(sfg.operations)) == 2 + assert sfg.input_count == 1 + assert sfg.output_count == 1 + + def test_outputs_construction(self, operation_tree): + sfg = SFG(outputs = [Output(operation_tree)]) + + assert len(list(sfg.components)) == 7 + assert len(list(sfg.operations)) == 4 + assert sfg.input_count == 0 + assert sfg.output_count == 1 + + def test_signals_construction(self, operation_tree): + sfg = SFG(output_signals = [Signal(source = operation_tree.output(0))]) + + assert len(list(sfg.components)) == 7 + assert len(list(sfg.operations)) == 4 + assert sfg.input_count == 0 + assert sfg.output_count == 1 + + +class TestEvaluation: + def test_evaluate_output(self, operation_tree): + sfg = SFG(outputs = [Output(operation_tree)]) + assert sfg.evaluate_output(0, []) == 5 + + def test_evaluate_output_large(self, large_operation_tree): + sfg = SFG(outputs = [Output(large_operation_tree)]) + assert sfg.evaluate_output(0, []) == 14 + + def test_evaluate_output_cycle(self, operation_graph_with_cycle): + sfg = SFG(outputs = [Output(operation_graph_with_cycle)]) + with pytest.raises(Exception): + sfg.evaluate_output(0, []) + + +class TestDeepCopy: + def test_deep_copy_no_duplicates(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(inp1, inp2, "ADD1") + mul1 = Multiplication(add1, inp3, "MUL1") + out1 = Output(mul1, "OUT1") + + mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") + mac_sfg_new = mac_sfg() + + assert mac_sfg.name == "mac_sfg" + assert mac_sfg_new.name == "" + + for g_id, component in mac_sfg._components_by_id.items(): + component_copy = mac_sfg_new.find_by_id(g_id) + assert component.name == component_copy.name + + def test_deep_copy(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + mul1 = Multiplication(None, None, "MUL1") + out1 = Output(None, "OUT1") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + add2.input(0).connect(add1, "S4") + add2.input(1).connect(inp3, "S3") + mul1.input(0).connect(add1, "S5") + mul1.input(1).connect(add2, "S6") + out1.input(0).connect(mul1, "S7") + + mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], id_number_offset = 100, name = "mac_sfg") + mac_sfg_new = mac_sfg(name = "mac_sfg2") + + assert mac_sfg.name == "mac_sfg" + assert mac_sfg_new.name == "mac_sfg2" + assert mac_sfg.id_number_offset == 100 + assert mac_sfg_new.id_number_offset == 100 + + for g_id, component in mac_sfg._components_by_id.items(): + component_copy = mac_sfg_new.find_by_id(g_id) + assert component.name == component_copy.name + + def test_deep_copy_with_new_sources(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(inp1, inp2, "ADD1") + mul1 = Multiplication(add1, inp3, "MUL1") + out1 = Output(mul1, "OUT1") + + mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") + + a = Addition(Constant(3), Constant(5)) + b = Constant(2) + mac_sfg_new = mac_sfg(a, b) + assert mac_sfg_new.input(0).signals[0].source.operation is a + assert mac_sfg_new.input(1).signals[0].source.operation is b + + +class TestComponents: + def test_advanced_components(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + mul1 = Multiplication(None, None, "MUL1") + out1 = Output(None, "OUT1") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + add2.input(0).connect(add1, "S4") + add2.input(1).connect(inp3, "S3") + mul1.input(0).connect(add1, "S5") + mul1.input(1).connect(add2, "S6") + out1.input(0).connect(mul1, "S7") + + mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") + + assert set([comp.name for comp in mac_sfg.components]) == {"INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"} diff --git a/test/test_simulation.py b/test/test_simulation.py new file mode 100644 index 0000000000000000000000000000000000000000..faa1f75eb12acccf26169f31849f61da83df598e --- /dev/null +++ b/test/test_simulation.py @@ -0,0 +1,135 @@ +import pytest +import numpy as np + +from b_asic import SFG, Output, Simulation + + +class TestSimulation: + def test_simulate_with_lambdas_as_input(self, sfg_two_inputs_two_outputs): + simulation = Simulation(sfg_two_inputs_two_outputs, [lambda n: n + 3, lambda n: 1 + n * 2], save_results = True) + + output = simulation.run_for(101) + + assert output[0] == 304 + assert output[1] == 505 + + assert simulation.results[100]["0"] == 304 + assert simulation.results[100]["1"] == 505 + + assert simulation.results[0]["in1"] == 3 + assert simulation.results[0]["in2"] == 1 + assert simulation.results[0]["add1"] == 4 + assert simulation.results[0]["add2"] == 5 + assert simulation.results[0]["0"] == 4 + assert simulation.results[0]["1"] == 5 + + assert simulation.results[1]["in1"] == 4 + assert simulation.results[1]["in2"] == 3 + assert simulation.results[1]["add1"] == 7 + assert simulation.results[1]["add2"] == 10 + assert simulation.results[1]["0"] == 7 + assert simulation.results[1]["1"] == 10 + + assert simulation.results[2]["in1"] == 5 + assert simulation.results[2]["in2"] == 5 + assert simulation.results[2]["add1"] == 10 + assert simulation.results[2]["add2"] == 15 + assert simulation.results[2]["0"] == 10 + assert simulation.results[2]["1"] == 15 + + assert simulation.results[3]["in1"] == 6 + assert simulation.results[3]["in2"] == 7 + assert simulation.results[3]["add1"] == 13 + assert simulation.results[3]["add2"] == 20 + assert simulation.results[3]["0"] == 13 + assert simulation.results[3]["1"] == 20 + + def test_simulate_with_numpy_arrays_as_input(self, sfg_two_inputs_two_outputs): + input0 = np.array([5, 9, 25, -5, 7]) + input1 = np.array([7, 3, 3, 54, 2]) + simulation = Simulation(sfg_two_inputs_two_outputs, [input0, input1]) + simulation.save_results = True + + output = simulation.run_for(5) + + assert output[0] == 9 + assert output[1] == 11 + + assert simulation.results[0]["in1"] == 5 + assert simulation.results[0]["in2"] == 7 + assert simulation.results[0]["add1"] == 12 + assert simulation.results[0]["add2"] == 19 + assert simulation.results[0]["0"] == 12 + assert simulation.results[0]["1"] == 19 + + assert simulation.results[1]["in1"] == 9 + assert simulation.results[1]["in2"] == 3 + assert simulation.results[1]["add1"] == 12 + assert simulation.results[1]["add2"] == 15 + assert simulation.results[1]["0"] == 12 + assert simulation.results[1]["1"] == 15 + + assert simulation.results[2]["in1"] == 25 + assert simulation.results[2]["in2"] == 3 + assert simulation.results[2]["add1"] == 28 + assert simulation.results[2]["add2"] == 31 + assert simulation.results[2]["0"] == 28 + assert simulation.results[2]["1"] == 31 + + assert simulation.results[3]["in1"] == -5 + assert simulation.results[3]["in2"] == 54 + assert simulation.results[3]["add1"] == 49 + assert simulation.results[3]["add2"] == 103 + assert simulation.results[3]["0"] == 49 + assert simulation.results[3]["1"] == 103 + + assert simulation.results[4]["0"] == 9 + assert simulation.results[4]["1"] == 11 + + def test_simulate_with_numpy_array_overflow(self, sfg_two_inputs_two_outputs): + input0 = np.array([5, 9, 25, -5, 7]) + input1 = np.array([7, 3, 3, 54, 2]) + simulation = Simulation(sfg_two_inputs_two_outputs, [input0, input1]) + simulation.run_for(5) + with pytest.raises(IndexError): + simulation.run_for(1) + + def test_simulate_nested(self, sfg_nested): + input0 = np.array([5, 9]) + input1 = np.array([7, 3]) + simulation = Simulation(sfg_nested, [input0, input1]) + + output0 = simulation.run() + output1 = simulation.run() + + assert output0[0] == 11405 + assert output1[0] == 4221 + + def test_simulate_delay(self, sfg_delay): + simulation = Simulation(sfg_delay, save_results = True) + simulation.set_input(0, [5, -2, 25, -6, 7, 0]) + simulation.run_for(6) + + assert simulation.results[0]["0"] == 0 + assert simulation.results[1]["0"] == 5 + assert simulation.results[2]["0"] == -2 + assert simulation.results[3]["0"] == 25 + assert simulation.results[4]["0"] == -6 + assert simulation.results[5]["0"] == 7 + + def test_simulate_accumulator(self, sfg_accumulator): + data_in = np.array([5, -2, 25, -6, 7, 0]) + reset = np.array([0, 0, 0, 1, 0, 0]) + simulation = Simulation(sfg_accumulator, [data_in, reset]) + output0 = simulation.run() + output1 = simulation.run() + output2 = simulation.run() + output3 = simulation.run() + output4 = simulation.run() + output5 = simulation.run() + assert output0[0] == 0 + assert output1[0] == 5 + assert output2[0] == 3 + assert output3[0] == 28 + assert output4[0] == 0 + assert output5[0] == 7 \ No newline at end of file