Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • da/B-ASIC
  • lukja239/B-ASIC
  • robal695/B-ASIC
3 results
Show changes
Commits on Source (84)
Showing
with 996 additions and 667 deletions
.vs/ .vs/
.vscode/ .vscode/
build*/ build*/
bin*/ bin*/
logs/ logs/
dist/ dist/
CMakeLists.txt.user* CMakeLists.txt.user*
*.autosave *.autosave
*.creator *.creator
*.creator.user* *.creator.user*
\#*\# \#*\#
/.emacs.desktop /.emacs.desktop
/.emacs.desktop.lock /.emacs.desktop.lock
*.elc *.elc
auto-save-list auto-save-list
tramp tramp
.\#* .\#*
*~ *~
.fuse_hudden* .fuse_hudden*
.directory .directory
.Trash-* .Trash-*
.nfs* .nfs*
Thumbs.db Thumbs.db
Thumbs.db:encryptable Thumbs.db:encryptable
ehthumbs.db ehthumbs.db
ehthumbs_vista.db ehthumbs_vista.db
$RECYCLE.BIN/ $RECYCLE.BIN/
*.stackdump *.stackdump
[Dd]esktop.ini [Dd]esktop.ini
*.egg-info *.egg-info
__pycache__/ __pycache__/
env/ env/
venv/ venv/
\ No newline at end of file
...@@ -4,7 +4,6 @@ TODO: More info. ...@@ -4,7 +4,6 @@ TODO: More info.
""" """
from b_asic.core_operations import * from b_asic.core_operations import *
from b_asic.graph_component import * from b_asic.graph_component import *
from b_asic.graph_id import *
from b_asic.operation import * from b_asic.operation import *
from b_asic.precedence_chart import * from b_asic.precedence_chart import *
from b_asic.port import * from b_asic.port import *
...@@ -12,3 +11,4 @@ from b_asic.schema import * ...@@ -12,3 +11,4 @@ from b_asic.schema import *
from b_asic.signal_flow_graph import * from b_asic.signal_flow_graph import *
from b_asic.signal import * from b_asic.signal import *
from b_asic.simulation import * from b_asic.simulation import *
from b_asic.special_operations import *
...@@ -4,43 +4,39 @@ TODO: More info. ...@@ -4,43 +4,39 @@ TODO: More info.
""" """
from numbers import Number from numbers import Number
from typing import Any from typing import Optional
from numpy import conjugate, sqrt, abs as np_abs from numpy import conjugate, sqrt, abs as np_abs
from b_asic.port import InputPort, OutputPort
from b_asic.graph_id import GraphIDType from b_asic.port import SignalSourceProvider, InputPort, OutputPort
from b_asic.operation import AbstractOperation from b_asic.operation import AbstractOperation
from b_asic.graph_component import Name, TypeName from b_asic.graph_component import Name, TypeName
class Input(AbstractOperation):
"""Input operation.
TODO: More info.
"""
# TODO: Implement all functions.
@property
def type_name(self) -> TypeName:
return "in"
class Constant(AbstractOperation): class Constant(AbstractOperation):
"""Constant value operation. """Constant value operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, value: Number = 0, name: Name = ""): def __init__(self, value: Number = 0, name: Name = ""):
super().__init__(name) super().__init__(input_count = 0, output_count = 1, name = name)
self.set_param("value", value)
self._output_ports = [OutputPort(0, self)] @property
self._parameters["value"] = value def type_name(self) -> TypeName:
return "c"
def evaluate(self): def evaluate(self):
return self.param("value") return self.param("value")
@property @property
def type_name(self) -> TypeName: def value(self) -> Number:
return "c" """TODO: docstring"""
return self.param("value")
@value.setter
def value(self, value: Number):
"""TODO: docstring"""
return self.set_param("value", value)
class Addition(AbstractOperation): class Addition(AbstractOperation):
...@@ -48,290 +44,228 @@ class Addition(AbstractOperation): ...@@ -48,290 +44,228 @@ class Addition(AbstractOperation):
TODO: More info. TODO: More info.
""" """
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(name) super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1])
self._input_ports = [InputPort(0, self), InputPort(1, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect(source1)
if source2 is not None:
self._input_ports[1].connect(source2)
def evaluate(self, a, b):
return a + b
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "add" return "add"
def evaluate(self, a, b):
return a + b
class Subtraction(AbstractOperation): class Subtraction(AbstractOperation):
"""Binary subtraction operation. """Binary subtraction operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(name) super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1])
self._input_ports = [InputPort(0, self), InputPort(1, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect(source1)
if source2 is not None:
self._input_ports[1].connect(source2)
def evaluate(self, a, b):
return a - b
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "sub" return "sub"
def evaluate(self, a, b):
return a - b
class Multiplication(AbstractOperation): class Multiplication(AbstractOperation):
"""Binary multiplication operation. """Binary multiplication operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(name) super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1])
self._input_ports = [InputPort(0, self), InputPort(1, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect(source1)
if source2 is not None:
self._input_ports[1].connect(source2)
def evaluate(self, a, b):
return a * b
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "mul" return "mul"
def evaluate(self, a, b):
return a * b
class Division(AbstractOperation): class Division(AbstractOperation):
"""Binary division operation. """Binary division operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(name) super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1])
self._input_ports = [InputPort(0, self), InputPort(1, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect(source1)
if source2 is not None:
self._input_ports[1].connect(source2)
def evaluate(self, a, b):
return a / b
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "div" return "div"
def evaluate(self, a, b):
return a / b
class SquareRoot(AbstractOperation): class SquareRoot(AbstractOperation):
"""Unary square root operation. """Unary square root operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, source1: OutputPort = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(name) super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self._input_ports = [InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect(source1)
def evaluate(self, a):
return sqrt((complex)(a))
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "sqrt" return "sqrt"
def evaluate(self, a):
return sqrt(complex(a))
class ComplexConjugate(AbstractOperation): class ComplexConjugate(AbstractOperation):
"""Unary complex conjugate operation. """Unary complex conjugate operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, source1: OutputPort = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(name) super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self._input_ports = [InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect(source1)
def evaluate(self, a):
return conjugate(a)
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "conj" return "conj"
def evaluate(self, a):
return conjugate(a)
class Max(AbstractOperation): class Max(AbstractOperation):
"""Binary max operation. """Binary max operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(name) super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1])
self._input_ports = [InputPort(0, self), InputPort(1, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None: @property
self._input_ports[0].connect(source1) def type_name(self) -> TypeName:
if source2 is not None: return "max"
self._input_ports[1].connect(source2)
def evaluate(self, a, b): def evaluate(self, a, b):
assert not isinstance(a, complex) and not isinstance(b, complex), \ assert not isinstance(a, complex) and not isinstance(b, complex), \
("core_operations.Max does not support complex numbers.") ("core_operations.Max does not support complex numbers.")
return a if a > b else b return a if a > b else b
@property
def type_name(self) -> TypeName:
return "max"
class Min(AbstractOperation): class Min(AbstractOperation):
"""Binary min operation. """Binary min operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(name) super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1])
self._input_ports = [InputPort(0, self), InputPort(1, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None: @property
self._input_ports[0].connect(source1) def type_name(self) -> TypeName:
if source2 is not None: return "min"
self._input_ports[1].connect(source2)
def evaluate(self, a, b): def evaluate(self, a, b):
assert not isinstance(a, complex) and not isinstance(b, complex), \ assert not isinstance(a, complex) and not isinstance(b, complex), \
("core_operations.Min does not support complex numbers.") ("core_operations.Min does not support complex numbers.")
return a if a < b else b return a if a < b else b
@property
def type_name(self) -> TypeName:
return "min"
class Absolute(AbstractOperation): class Absolute(AbstractOperation):
"""Unary absolute value operation. """Unary absolute value operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, source1: OutputPort = None, name: Name = ""): def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(name) super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self._input_ports = [InputPort(0, self)]
self._output_ports = [OutputPort(0, self)]
if source1 is not None:
self._input_ports[0].connect(source1)
def evaluate(self, a):
return np_abs(a)
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "abs" return "abs"
def evaluate(self, a):
return np_abs(a)
class ConstantMultiplication(AbstractOperation): class ConstantMultiplication(AbstractOperation):
"""Unary constant multiplication operation. """Unary constant multiplication operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, coefficient: Number, source1: OutputPort = None, name: Name = ""): def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(name) super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self._input_ports = [InputPort(0, self)] self.set_param("value", value)
self._output_ports = [OutputPort(0, self)]
self._parameters["coefficient"] = coefficient
if source1 is not None:
self._input_ports[0].connect(source1)
def evaluate(self, a):
return a * self.param("coefficient")
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "cmul" return "cmul"
def evaluate(self, a):
return a * self.param("value")
class ConstantAddition(AbstractOperation): class ConstantAddition(AbstractOperation):
"""Unary constant addition operation. """Unary constant addition operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, coefficient: Number, source1: OutputPort = None, name: Name = ""): def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(name) super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self._input_ports = [InputPort(0, self)] self.set_param("value", value)
self._output_ports = [OutputPort(0, self)]
self._parameters["coefficient"] = coefficient
if source1 is not None:
self._input_ports[0].connect(source1)
def evaluate(self, a):
return a + self.param("coefficient")
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "cadd" return "cadd"
def evaluate(self, a):
return a + self.param("value")
class ConstantSubtraction(AbstractOperation): class ConstantSubtraction(AbstractOperation):
"""Unary constant subtraction operation. """Unary constant subtraction operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, coefficient: Number, source1: OutputPort = None, name: Name = ""): def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(name) super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self._input_ports = [InputPort(0, self)] self.set_param("value", value)
self._output_ports = [OutputPort(0, self)]
self._parameters["coefficient"] = coefficient
if source1 is not None:
self._input_ports[0].connect(source1)
def evaluate(self, a):
return a - self.param("coefficient")
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "csub" return "csub"
def evaluate(self, a):
return a - self.param("value")
class ConstantDivision(AbstractOperation): class ConstantDivision(AbstractOperation):
"""Unary constant division operation. """Unary constant division operation.
TODO: More info. TODO: More info.
""" """
def __init__(self, coefficient: Number, source1: OutputPort = None, name: Name = ""): def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(name) super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self._input_ports = [InputPort(0, self)] self.set_param("value", value)
self._output_ports = [OutputPort(0, self)]
self._parameters["coefficient"] = coefficient
if source1 is not None: @property
self._input_ports[0].connect(source1) def type_name(self) -> TypeName:
return "cdiv"
def evaluate(self, a): def evaluate(self, a):
return a / self.param("coefficient") return a / self.param("value")
class Butterfly(AbstractOperation):
"""Butterfly operation that returns two outputs.
The first output is a + b and the second output is a - b.
TODO: More info.
"""
def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 2, output_count = 2, name = name, input_sources = [src0, src1])
def evaluate(self, a, b):
return a + b, a - b
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "cdiv" return "bfly"
...@@ -4,6 +4,7 @@ TODO: More info. ...@@ -4,6 +4,7 @@ TODO: More info.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from copy import copy
from typing import NewType from typing import NewType
Name = NewType("Name", str) Name = NewType("Name", str)
...@@ -33,6 +34,11 @@ class GraphComponent(ABC): ...@@ -33,6 +34,11 @@ class GraphComponent(ABC):
"""Set the name of the graph component to the entered name.""" """Set the name of the graph component to the entered name."""
raise NotImplementedError raise NotImplementedError
@abstractmethod
def copy_unconnected(self) -> "GraphComponent":
"""Get a copy of this graph component, except without any connected components."""
raise NotImplementedError
class AbstractGraphComponent(GraphComponent): class AbstractGraphComponent(GraphComponent):
"""Abstract Graph Component class which is a component of a signal flow graph. """Abstract Graph Component class which is a component of a signal flow graph.
...@@ -52,3 +58,8 @@ class AbstractGraphComponent(GraphComponent): ...@@ -52,3 +58,8 @@ class AbstractGraphComponent(GraphComponent):
@name.setter @name.setter
def name(self, name: Name) -> None: def name(self, name: Name) -> None:
self._name = name self._name = name
def copy_unconnected(self) -> GraphComponent:
new_comp = self.__class__()
new_comp.name = copy(self.name)
return new_comp
\ No newline at end of file
"""@package docstring
B-ASIC Graph ID module for handling IDs of different objects in a graph.
TODO: More info
"""
from collections import defaultdict
from typing import NewType, DefaultDict
GraphID = NewType("GraphID", str)
GraphIDType = NewType("GraphIDType", str)
GraphIDNumber = NewType("GraphIDNumber", int)
class GraphIDGenerator:
"""A class that generates Graph IDs for objects."""
_next_id_number: DefaultDict[GraphIDType, GraphIDNumber]
def __init__(self):
self._next_id_number = defaultdict(lambda: 1) # Initalises every key element to 1
def get_next_id(self, graph_id_type: GraphIDType) -> GraphID:
"""Return the next graph id for a certain graph id type."""
graph_id = graph_id_type + str(self._next_id_number[graph_id_type])
self._next_id_number[graph_id_type] += 1 # Increase the current id number
return graph_id
...@@ -3,51 +3,86 @@ B-ASIC Operation Module. ...@@ -3,51 +3,86 @@ B-ASIC Operation Module.
TODO: More info. TODO: More info.
""" """
import collections
from abc import abstractmethod from abc import abstractmethod
from copy import deepcopy
from numbers import Number from numbers import Number
from typing import List, Dict, Optional, Any, Set, TYPE_CHECKING from typing import List, Sequence, Iterable, Dict, Optional, Any, Set, Generator, Union
from collections import deque from collections import deque
from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
from b_asic.simulation import SimulationState, OperationState from b_asic.port import SignalSourceProvider, InputPort, OutputPort
from b_asic.signal import Signal
if TYPE_CHECKING:
from b_asic.port import InputPort, OutputPort
class Operation(GraphComponent): class Operation(GraphComponent, SignalSourceProvider):
"""Operation interface. """Operation interface.
TODO: More info. TODO: More info.
""" """
@abstractmethod @abstractmethod
def inputs(self) -> "List[InputPort]": def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]":
"""Overloads the addition operator to make it return a new Addition operation
object that is connected to the self and other objects. If other is a number then
returns a ConstantAddition operation object instead.
"""
raise NotImplementedError
@abstractmethod
def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]":
"""Overloads the subtraction operator to make it return a new Subtraction operation
object that is connected to the self and other objects. If other is a number then
returns a ConstantSubtraction operation object instead.
"""
raise NotImplementedError
@abstractmethod
def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
"""Overloads the multiplication operator to make it return a new Multiplication operation
object that is connected to the self and other objects. If other is a number then
returns a ConstantMultiplication operation object instead.
"""
raise NotImplementedError
@abstractmethod
def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]":
"""Overloads the division operator to make it return a new Division operation
object that is connected to the self and other objects. If other is a number then
returns a ConstantDivision operation object instead.
"""
raise NotImplementedError
@property
@abstractmethod
def inputs(self) -> List[InputPort]:
"""Get a list of all input ports.""" """Get a list of all input ports."""
raise NotImplementedError raise NotImplementedError
@property
@abstractmethod @abstractmethod
def outputs(self) -> "List[OutputPort]": def outputs(self) -> List[OutputPort]:
"""Get a list of all output ports.""" """Get a list of all output ports."""
raise NotImplementedError raise NotImplementedError
@property
@abstractmethod @abstractmethod
def input_count(self) -> int: def input_count(self) -> int:
"""Get the number of input ports.""" """Get the number of input ports."""
raise NotImplementedError raise NotImplementedError
@property
@abstractmethod @abstractmethod
def output_count(self) -> int: def output_count(self) -> int:
"""Get the number of output ports.""" """Get the number of output ports."""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def input(self, i: int) -> "InputPort": def input(self, i: int) -> InputPort:
"""Get the input port at index i.""" """Get the input port at index i."""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def output(self, i: int) -> "OutputPort": def output(self, i: int) -> OutputPort:
"""Get the output port at index i.""" """Get the output port at index i."""
raise NotImplementedError raise NotImplementedError
...@@ -66,19 +101,23 @@ class Operation(GraphComponent): ...@@ -66,19 +101,23 @@ class Operation(GraphComponent):
@abstractmethod @abstractmethod
def set_param(self, name: str, value: Any) -> None: def set_param(self, name: str, value: Any) -> None:
"""Set the value of a parameter. """Set the value of a parameter.
The parameter must be defined. Adds the parameter if it is not already defined.
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def evaluate_outputs(self, state: "SimulationState") -> List[Number]: def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
"""Simulate the circuit until its iteration count matches that of the simulation state, """Evaluate the output at index i of this operation with the given input values.
then return the resulting output vector. The returned sequence contains results corresponding to each output of this operation,
where a value of None means it was not evaluated.
The value at index i is guaranteed to have been evaluated, while the others may or may not
have been evaluated depending on what is the most efficient.
For example, Butterfly().evaluate_output(1, [5, 4]) may result in either (9, 1) or (None, 1).
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def split(self) -> "List[Operation]": def split(self) -> Iterable["Operation"]:
"""Split the operation into multiple operations. """Split the operation into multiple operations.
If splitting is not possible, this may return a list containing only the operation itself. If splitting is not possible, this may return a list containing only the operation itself.
""" """
...@@ -86,28 +125,53 @@ class Operation(GraphComponent): ...@@ -86,28 +125,53 @@ class Operation(GraphComponent):
@property @property
@abstractmethod @abstractmethod
def neighbors(self) -> "List[Operation]": def neighbors(self) -> Iterable["Operation"]:
"""Return all operations that are connected by signals to this operation. """Return all operations that are connected by signals to this operation.
If no neighbors are found, this returns an empty list. If no neighbors are found, this returns an empty list.
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def traverse(self) -> Generator["Operation", None, None]:
"""Get a generator that recursively iterates through all operations that are connected by signals to this operation,
as well as the ones that they are connected to.
"""
raise NotImplementedError
class AbstractOperation(Operation, AbstractGraphComponent): class AbstractOperation(Operation, AbstractGraphComponent):
"""Generic abstract operation class which most implementations will derive from. """Generic abstract operation class which most implementations will derive from.
TODO: More info. TODO: More info.
""" """
_input_ports: List["InputPort"] _input_ports: List[InputPort]
_output_ports: List["OutputPort"] _output_ports: List[OutputPort]
_parameters: Dict[str, Optional[Any]] _parameters: Dict[str, Optional[Any]]
def __init__(self, name: Name = ""): def __init__(self, input_count: int, output_count: int, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
super().__init__(name) super().__init__(name)
self._input_ports = [] self._input_ports = []
self._output_ports = [] self._output_ports = []
self._parameters = {} self._parameters = {}
# Allocate input ports.
for i in range(input_count):
self._input_ports.append(InputPort(self, i))
# Allocate output ports.
for i in range(output_count):
self._output_ports.append(OutputPort(self, i))
# Connect given input sources, if any.
if input_sources is not None:
source_count = len(input_sources)
if source_count != input_count:
raise ValueError(
f"Operation expected {input_count} input sources but only got {source_count}")
for i, src in enumerate(input_sources):
if src is not None:
self._input_ports[i].connect(src.source)
@abstractmethod @abstractmethod
def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ
"""Evaluate the operation and generate a list of output values given a """Evaluate the operation and generate a list of output values given a
...@@ -115,24 +179,61 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -115,24 +179,61 @@ class AbstractOperation(Operation, AbstractGraphComponent):
""" """
raise NotImplementedError raise NotImplementedError
def inputs(self) -> List["InputPort"]: def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]":
# Import here to avoid circular imports.
from b_asic.core_operations import Addition, ConstantAddition
if isinstance(src, Number):
return ConstantAddition(src, self)
return Addition(self, src)
def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]":
# Import here to avoid circular imports.
from b_asic.core_operations import Subtraction, ConstantSubtraction
if isinstance(src, Number):
return ConstantSubtraction(src, self)
return Subtraction(self, src)
def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
# Import here to avoid circular imports.
from b_asic.core_operations import Multiplication, ConstantMultiplication
if isinstance(src, Number):
return ConstantMultiplication(src, self)
return Multiplication(self, src)
def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]":
# Import here to avoid circular imports.
from b_asic.core_operations import Division, ConstantDivision
if isinstance(src, Number):
return ConstantDivision(src, self)
return Division(self, src)
@property
def inputs(self) -> List[InputPort]:
return self._input_ports.copy() return self._input_ports.copy()
def outputs(self) -> List["OutputPort"]: @property
def outputs(self) -> List[OutputPort]:
return self._output_ports.copy() return self._output_ports.copy()
@property
def input_count(self) -> int: def input_count(self) -> int:
return len(self._input_ports) return len(self._input_ports)
@property
def output_count(self) -> int: def output_count(self) -> int:
return len(self._output_ports) return len(self._output_ports)
def input(self, i: int) -> "InputPort": def input(self, i: int) -> InputPort:
return self._input_ports[i] return self._input_ports[i]
def output(self, i: int) -> "OutputPort": def output(self, i: int) -> OutputPort:
return self._output_ports[i] return self._output_ports[i]
@property
def params(self) -> Dict[str, Optional[Any]]: def params(self) -> Dict[str, Optional[Any]]:
return self._parameters.copy() return self._parameters.copy()
...@@ -140,63 +241,51 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -140,63 +241,51 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return self._parameters.get(name) return self._parameters.get(name)
def set_param(self, name: str, value: Any) -> None: def set_param(self, name: str, value: Any) -> None:
assert name in self._parameters # TODO: Error message.
self._parameters[name] = value self._parameters[name] = value
def evaluate_outputs(self, state: SimulationState) -> List[Number]: def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
# TODO: Check implementation. result = self.evaluate(*input_values)
input_count: int = self.input_count() if isinstance(result, collections.Sequence):
output_count: int = self.output_count() if len(result) != self.output_count:
assert input_count == len(self._input_ports) # TODO: Error message. raise RuntimeError(
assert output_count == len(self._output_ports) # TODO: Error message. "Operation evaluated to incorrect number of outputs")
return result
self_state: OperationState = state.operation_states[self] if isinstance(result, Number):
if self.output_count != 1:
while self_state.iteration < state.iteration: raise RuntimeError(
input_values: List[Number] = [0] * input_count "Operation evaluated to incorrect number of outputs")
for i in range(input_count): return [result]
source: Signal = self._input_ports[i].signal raise RuntimeError("Operation evaluated to invalid type")
input_values[i] = source.operation.evaluate_outputs(state)[
source.port_index] def split(self) -> Iterable[Operation]:
# Import here to avoid circular imports.
self_state.output_values = self.evaluate(input_values) from b_asic.special_operations import Input
# TODO: Error message. try:
assert len(self_state.output_values) == output_count result = self.evaluate([Input()] * self.input_count)
self_state.iteration += 1 if isinstance(result, collections.Sequence) and all(isinstance(e, Operation) for e in result):
for i in range(output_count): return result
for signal in self._output_ports[i].signals(): if isinstance(result, Operation):
destination: Signal = signal.destination return [result]
destination.evaluate_outputs(state) except TypeError:
pass
return self_state.output_values except ValueError:
pass
def split(self) -> List[Operation]:
# TODO: Check implementation.
results = self.evaluate(self._input_ports)
if all(isinstance(e, Operation) for e in results):
return results
return [self] return [self]
@property @property
def neighbors(self) -> List[Operation]: def neighbors(self) -> Iterable[Operation]:
neighbors: List[Operation] = [] neighbors = []
for port in self._input_ports: for port in self._input_ports:
for signal in port.signals: for signal in port.signals:
neighbors.append(signal.source.operation) neighbors.append(signal.source.operation)
for port in self._output_ports: for port in self._output_ports:
for signal in port.signals: for signal in port.signals:
neighbors.append(signal.destination.operation) neighbors.append(signal.destination.operation)
return neighbors return neighbors
def traverse(self) -> Operation: def traverse(self) -> Generator[Operation, None, None]:
"""Traverse the operation tree and return a generator with start point in the operation.""" # Breadth first search.
return self._breadth_first_search() visited = {self}
def _breadth_first_search(self) -> Operation:
"""Use breadth first search to traverse the operation tree."""
visited: Set[Operation] = {self}
queue = deque([self]) queue = deque([self])
while queue: while queue:
operation = queue.popleft() operation = queue.popleft()
...@@ -206,63 +295,17 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -206,63 +295,17 @@ class AbstractOperation(Operation, AbstractGraphComponent):
visited.add(n_operation) visited.add(n_operation)
queue.append(n_operation) queue.append(n_operation)
def __add__(self, other): @property
"""Overloads the addition operator to make it return a new Addition operation def source(self) -> OutputPort:
object that is connected to the self and other objects. If other is a number then if self.output_count != 1:
returns a ConstantAddition operation object instead. diff = "more" if self.output_count > 1 else "less"
""" raise TypeError(
# Import here to avoid circular imports. f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output")
from b_asic.core_operations import Addition, ConstantAddition return self.output(0)
if isinstance(other, Operation): def copy_unconnected(self) -> GraphComponent:
return Addition(self.output(0), other.output(0)) new_comp: AbstractOperation = super().copy_unconnected()
elif isinstance(other, Number): for name, value in self.params.items():
return ConstantAddition(other, self.output(0)) new_comp.set_param(name, deepcopy(
else: value)) # pylint: disable=no-member
raise TypeError("Other type is not an Operation or a Number.") return new_comp
def __sub__(self, other):
"""Overloads the subtraction operator to make it return a new Subtraction operation
object that is connected to the self and other objects. If other is a number then
returns a ConstantSubtraction operation object instead.
"""
# Import here to avoid circular imports.
from b_asic.core_operations import Subtraction, ConstantSubtraction
if isinstance(other, Operation):
return Subtraction(self.output(0), other.output(0))
elif isinstance(other, Number):
return ConstantSubtraction(other, self.output(0))
else:
raise TypeError("Other type is not an Operation or a Number.")
def __mul__(self, other):
"""Overloads the multiplication operator to make it return a new Multiplication operation
object that is connected to the self and other objects. If other is a number then
returns a ConstantMultiplication operation object instead.
"""
# Import here to avoid circular imports.
from b_asic.core_operations import Multiplication, ConstantMultiplication
if isinstance(other, Operation):
return Multiplication(self.output(0), other.output(0))
elif isinstance(other, Number):
return ConstantMultiplication(other, self.output(0))
else:
raise TypeError("Other type is not an Operation or a Number.")
def __truediv__(self, other):
"""Overloads the division operator to make it return a new Division operation
object that is connected to the self and other objects. If other is a number then
returns a ConstantDivision operation object instead.
"""
# Import here to avoid circular imports.
from b_asic.core_operations import Division, ConstantDivision
if isinstance(other, Operation):
return Division(self.output(0), other.output(0))
elif isinstance(other, Number):
return ConstantDivision(other, self.output(0))
else:
raise TypeError("Other type is not an Operation or a Number.")
...@@ -4,12 +4,15 @@ TODO: More info. ...@@ -4,12 +4,15 @@ TODO: More info.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import NewType, Optional, List from copy import copy
from typing import NewType, Optional, List, Iterable, TYPE_CHECKING
from b_asic.operation import Operation
from b_asic.signal import Signal from b_asic.signal import Signal
from b_asic.graph_component import Name
if TYPE_CHECKING:
from b_asic.operation import Operation
PortIndex = NewType("PortIndex", int)
class Port(ABC): class Port(ABC):
"""Port Interface. """Port Interface.
...@@ -19,59 +22,33 @@ class Port(ABC): ...@@ -19,59 +22,33 @@ class Port(ABC):
@property @property
@abstractmethod @abstractmethod
def operation(self) -> Operation: def operation(self) -> "Operation":
"""Return the connected operation.""" """Return the connected operation."""
raise NotImplementedError raise NotImplementedError
@property @property
@abstractmethod @abstractmethod
def index(self) -> PortIndex: def index(self) -> int:
"""Return the unique PortIndex.""" """Return the index of the port."""
raise NotImplementedError
@property
@abstractmethod
def signals(self) -> List[Signal]:
"""Return a list of all connected signals."""
raise NotImplementedError
@abstractmethod
def signal(self, i: int = 0) -> Signal:
"""Return the connected signal at index i.
Keyword argumens:
i: integer index of the signal requsted.
"""
raise NotImplementedError raise NotImplementedError
@property @property
@abstractmethod
def connected_ports(self) -> List["Port"]:
"""Return a list of all connected Ports."""
raise NotImplementedError
@abstractmethod @abstractmethod
def signal_count(self) -> int: def signal_count(self) -> int:
"""Return the number of connected signals.""" """Return the number of connected signals."""
raise NotImplementedError raise NotImplementedError
@property
@abstractmethod @abstractmethod
def connect(self, port: "Port") -> Signal: def signals(self) -> Iterable[Signal]:
"""Create and return a signal that is connected to this port and the entered """Return all connected signals."""
port and connect this port to the signal and the entered port to the signal."""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def add_signal(self, signal: Signal) -> None: def add_signal(self, signal: Signal) -> None:
"""Connect this port to the entered signal. If the entered signal isn't connected to """Connect this port to the entered signal. If the entered signal isn't connected to
this port then connect the entered signal to the port aswell.""" this port then connect the entered signal to the port aswell.
raise NotImplementedError """
@abstractmethod
def disconnect(self, port: "Port") -> None:
"""Disconnect the entered port from the port by removing it from the ports signal.
If the entered port is still connected to this ports signal then disconnect the entered
port from the signal aswell."""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
...@@ -97,127 +74,131 @@ class AbstractPort(Port): ...@@ -97,127 +74,131 @@ class AbstractPort(Port):
Handles functionality for port id and saves the connection to the parent operation. Handles functionality for port id and saves the connection to the parent operation.
""" """
_operation: "Operation"
_index: int _index: int
_operation: Operation
def __init__(self, index: int, operation: Operation): def __init__(self, operation: "Operation", index: int):
self._index = index
self._operation = operation self._operation = operation
self._index = index
@property @property
def operation(self) -> Operation: def operation(self) -> "Operation":
return self._operation return self._operation
@property @property
def index(self) -> PortIndex: def index(self) -> int:
return self._index return self._index
class SignalSourceProvider(ABC):
"""Signal source provider interface.
TODO: More info.
"""
@property
@abstractmethod
def source(self) -> "OutputPort":
"""Get the main source port provided by this object."""
raise NotImplementedError
class InputPort(AbstractPort): class InputPort(AbstractPort):
"""Input port. """Input port.
TODO: More info. TODO: More info.
""" """
_source_signal: Optional[Signal] _source_signal: Optional[Signal]
_value_length: Optional[int]
def __init__(self, port_id: PortIndex, operation: Operation): def __init__(self, operation: "Operation", index: int):
super().__init__(port_id, operation) super().__init__(operation, index)
self._source_signal = None self._source_signal = None
self._value_length = None
@property @property
def signals(self) -> List[Signal]:
return [] if self._source_signal is None else [self._source_signal]
def signal(self, i: int = 0) -> Signal:
assert 0 <= i < self.signal_count(), "Signal index out of bound."
assert self._source_signal is not None, "No Signal connect to InputPort."
return self._source_signal
@property
def connected_ports(self) -> List[Port]:
return [] if self._source_signal is None or self._source_signal.source is None \
else [self._source_signal.source]
def signal_count(self) -> int: def signal_count(self) -> int:
return 0 if self._source_signal is None else 1 return 0 if self._source_signal is None else 1
def connect(self, port: "OutputPort") -> Signal: @property
assert self._source_signal is None, "Connecting new port to already connected input port." def signals(self) -> Iterable[Signal]:
return Signal(port, self) # self._source_signal is set by the signal constructor return [] if self._source_signal is None else [self._source_signal]
def add_signal(self, signal: Signal) -> None: def add_signal(self, signal: Signal) -> None:
assert self._source_signal is None, "Connecting new port to already connected input port." assert self._source_signal is None, "Input port may have only one signal added."
self._source_signal: Signal = signal assert signal is not self._source_signal, "Attempted to add already connected signal."
if self is not signal.destination: self._source_signal = signal
# Connect this inputport as destination for this signal if it isn't already. signal.set_destination(self)
signal.set_destination(self)
def disconnect(self, port: "OutputPort") -> None:
assert self._source_signal.source is port, "The entered port is not connected to this port."
self._source_signal.remove_source()
def remove_signal(self, signal: Signal) -> None: def remove_signal(self, signal: Signal) -> None:
old_signal: Signal = self._source_signal assert signal is self._source_signal, "Attempted to remove already removed signal."
self._source_signal = None self._source_signal = None
if self is old_signal.destination: signal.remove_destination()
# Disconnect the dest of the signal if this inputport currently is the dest
old_signal.remove_destination()
def clear(self) -> None: def clear(self) -> None:
self.remove_signal(self._source_signal) if self._source_signal is not None:
self.remove_signal(self._source_signal)
@property
def connected_source(self) -> Optional["OutputPort"]:
"""Get the output port that is currently connected to this input port,
or None if it is unconnected.
"""
return None if self._source_signal is None else self._source_signal.source
def connect(self, src: SignalSourceProvider, name: Name = "") -> Signal:
"""Connect the provided signal source to this input port by creating a new signal.
Returns the new signal.
"""
assert self._source_signal is None, "Attempted to connect already connected input port."
# self._source_signal is set by the signal constructor.
return Signal(source=src.source, destination=self, name=name)
@property
def value_length(self) -> Optional[int]:
"""Get the number of bits that this port should truncate received values to."""
return self._value_length
@value_length.setter
def value_length(self, bits: Optional[int]) -> None:
"""Set the number of bits that this port should truncate received values to."""
assert bits is None or (isinstance(
bits, int) and bits >= 0), "Value length must be non-negative."
self._value_length = bits
class OutputPort(AbstractPort):
class OutputPort(AbstractPort, SignalSourceProvider):
"""Output port. """Output port.
TODO: More info. TODO: More info.
""" """
_destination_signals: List[Signal] _destination_signals: List[Signal]
def __init__(self, port_id: PortIndex, operation: Operation): def __init__(self, operation: "Operation", index: int):
super().__init__(port_id, operation) super().__init__(operation, index)
self._destination_signals = [] self._destination_signals = []
@property @property
def signals(self) -> List[Signal]:
return self._destination_signals.copy()
def signal(self, i: int = 0) -> Signal:
assert 0 <= i < self.signal_count(), "Signal index out of bounds."
return self._destination_signals[i]
@property
def connected_ports(self) -> List[Port]:
return [signal.destination for signal in self._destination_signals \
if signal.destination is not None]
def signal_count(self) -> int: def signal_count(self) -> int:
return len(self._destination_signals) return len(self._destination_signals)
def connect(self, port: InputPort) -> Signal: @property
return Signal(self, port) # Signal is added to self._destination_signals in signal constructor def signals(self) -> Iterable[Signal]:
return self._destination_signals
def add_signal(self, signal: Signal) -> None: def add_signal(self, signal: Signal) -> None:
assert signal not in self.signals, \ assert signal not in self._destination_signals, "Attempted to add already connected signal."
"Attempting to connect to Signal already connected."
self._destination_signals.append(signal) self._destination_signals.append(signal)
if self is not signal.source: signal.set_source(self)
# Connect this outputport to the signal if it isn't already
signal.set_source(self)
def disconnect(self, port: InputPort) -> None:
assert port in self.connected_ports, "Attempting to disconnect port that isn't connected."
for sig in self._destination_signals:
if sig.destination is port:
sig.remove_destination()
break
def remove_signal(self, signal: Signal) -> None: def remove_signal(self, signal: Signal) -> None:
i: int = self._destination_signals.index(signal) assert signal in self._destination_signals, "Attempted to remove already removed signal."
old_signal: Signal = self._destination_signals[i] self._destination_signals.remove(signal)
del self._destination_signals[i] signal.remove_source()
if self is old_signal.source:
old_signal.remove_source()
def clear(self) -> None: def clear(self) -> None:
for signal in self._destination_signals: for signal in copy(self._destination_signals):
self.remove_signal(signal) self.remove_signal(signal)
@property
def source(self) -> "OutputPort":
return self
...@@ -12,30 +12,26 @@ if TYPE_CHECKING: ...@@ -12,30 +12,26 @@ if TYPE_CHECKING:
class Signal(AbstractGraphComponent): class Signal(AbstractGraphComponent):
"""A connection between two ports.""" """A connection between two ports."""
_source: "OutputPort" _source: Optional["OutputPort"]
_destination: "InputPort" _destination: Optional["InputPort"]
def __init__(self, source: Optional["OutputPort"] = None, \ def __init__(self, source: Optional["OutputPort"] = None, \
destination: Optional["InputPort"] = None, name: Name = ""): destination: Optional["InputPort"] = None, name: Name = ""):
super().__init__(name) super().__init__(name)
self._source = None
self._source = source self._destination = None
self._destination = destination
if source is not None: if source is not None:
self.set_source(source) self.set_source(source)
if destination is not None: if destination is not None:
self.set_destination(destination) self.set_destination(destination)
@property @property
def source(self) -> "OutputPort": def source(self) -> Optional["OutputPort"]:
"""Return the source OutputPort of the signal.""" """Return the source OutputPort of the signal."""
return self._source return self._source
@property @property
def destination(self) -> "InputPort": def destination(self) -> Optional["InputPort"]:
"""Return the destination "InputPort" of the signal.""" """Return the destination "InputPort" of the signal."""
return self._destination return self._destination
...@@ -47,11 +43,11 @@ class Signal(AbstractGraphComponent): ...@@ -47,11 +43,11 @@ class Signal(AbstractGraphComponent):
Keyword arguments: Keyword arguments:
- src: OutputPort to connect as source to the signal. - src: OutputPort to connect as source to the signal.
""" """
self.remove_source() if src is not self._source:
self._source = src self.remove_source()
if self not in src.signals: self._source = src
# If the new source isn't connected to this signal then connect it. if self not in src.signals:
src.add_signal(self) src.add_signal(self)
def set_destination(self, dest: "InputPort") -> None: def set_destination(self, dest: "InputPort") -> None:
"""Disconnect the previous destination InputPort of the signal and """Disconnect the previous destination InputPort of the signal and
...@@ -61,11 +57,11 @@ class Signal(AbstractGraphComponent): ...@@ -61,11 +57,11 @@ class Signal(AbstractGraphComponent):
Keywords argments: Keywords argments:
- dest: InputPort to connect as destination to the signal. - dest: InputPort to connect as destination to the signal.
""" """
self.remove_destination() if dest is not self._destination:
self._destination = dest self.remove_destination()
if self not in dest.signals: self._destination = dest
# If the new destination isn't connected to tis signal then connect it. if self not in dest.signals:
dest.add_signal(self) dest.add_signal(self)
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
...@@ -74,23 +70,21 @@ class Signal(AbstractGraphComponent): ...@@ -74,23 +70,21 @@ class Signal(AbstractGraphComponent):
def remove_source(self) -> None: def remove_source(self) -> None:
"""Disconnect the source OutputPort of the signal. If the source port """Disconnect the source OutputPort of the signal. If the source port
still is connected to this signal then also disconnect the source port.""" still is connected to this signal then also disconnect the source port."""
if self._source is not None: src = self._source
old_source: "OutputPort" = self._source if src is not None:
self._source = None self._source = None
if self in old_source.signals: if self in src.signals:
# If the old destination port still is connected to this signal, then disconnect it. src.remove_signal(self)
old_source.remove_signal(self)
def remove_destination(self) -> None: def remove_destination(self) -> None:
"""Disconnect the destination InputPort of the signal.""" """Disconnect the destination InputPort of the signal."""
if self._destination is not None: dest = self._destination
old_destination: "InputPort" = self._destination if dest is not None:
self._destination = None self._destination = None
if self in old_destination.signals: if self in dest.signals:
# If the old destination port still is connected to this signal, then disconnect it. dest.remove_signal(self)
old_destination.remove_signal(self)
def is_connected(self) -> bool: def dangling(self) -> bool:
"""Returns true if the signal is connected to both a source and a destination, """Returns true if the signal is missing either a source or a destination,
else false.""" else false."""
return self._source is not None and self._destination is not None return self._source is None or self._destination is None
...@@ -3,14 +3,33 @@ B-ASIC Signal Flow Graph Module. ...@@ -3,14 +3,33 @@ B-ASIC Signal Flow Graph Module.
TODO: More info. TODO: More info.
""" """
from typing import List, Dict, Optional, DefaultDict from typing import NewType, List, Iterable, Sequence, Dict, Optional, DefaultDict, Set
from collections import defaultdict from numbers import Number
from collections import defaultdict, deque
from b_asic.operation import Operation from b_asic.port import SignalSourceProvider, OutputPort
from b_asic.operation import AbstractOperation from b_asic.operation import Operation, AbstractOperation
from b_asic.signal import Signal from b_asic.signal import Signal
from b_asic.graph_id import GraphIDGenerator, GraphID
from b_asic.graph_component import GraphComponent, Name, TypeName from b_asic.graph_component import GraphComponent, Name, TypeName
from b_asic.special_operations import Input, Output
GraphID = NewType("GraphID", str)
GraphIDNumber = NewType("GraphIDNumber", int)
class GraphIDGenerator:
"""A class that generates Graph IDs for objects."""
_next_id_number: DefaultDict[TypeName, GraphIDNumber]
def __init__(self, id_number_offset: GraphIDNumber = 0):
self._next_id_number = defaultdict(lambda: id_number_offset)
def next_id(self, type_name: TypeName) -> GraphID:
"""Return the next graph id for a certain graph id type."""
self._next_id_number[type_name] += 1
return type_name + str(self._next_id_number[type_name])
class SFG(AbstractOperation): class SFG(AbstractOperation):
...@@ -18,51 +37,162 @@ class SFG(AbstractOperation): ...@@ -18,51 +37,162 @@ class SFG(AbstractOperation):
TODO: More info. TODO: More info.
""" """
_graph_components_by_id: Dict[GraphID, GraphComponent] _components_by_id: Dict[GraphID, GraphComponent]
_graph_components_by_name: DefaultDict[Name, List[GraphComponent]] _components_by_name: DefaultDict[Name, List[GraphComponent]]
_graph_id_generator: GraphIDGenerator _graph_id_generator: GraphIDGenerator
_input_operations: List[Input]
_output_operations: List[Output]
_original_components_added: Set[GraphComponent]
_original_input_signals: Dict[Signal, int]
_original_output_signals: Dict[Signal, int]
def __init__(self, input_signals: List[Signal] = None, output_signals: List[Signal] = None, \ def __init__(self, input_signals: Sequence[Signal] = [], output_signals: Sequence[Signal] = [],
ops: List[Operation] = None, **kwds): inputs: Sequence[Input] = [], outputs: Sequence[Output] = [],
super().__init__(**kwds) id_number_offset: GraphIDNumber = 0, name: Name = "",
if input_signals is None: input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
input_signals = [] super().__init__(
if output_signals is None: input_count=len(input_signals) + len(inputs),
output_signals = [] output_count=len(output_signals) + len(outputs),
if ops is None: name=name,
ops = [] input_sources=input_sources)
self._graph_components_by_id = dict() # Maps Graph ID to objects self._components_by_id = dict()
self._graph_components_by_name = defaultdict(list) # Maps Name to objects self._components_by_name = defaultdict(list)
self._graph_id_generator = GraphIDGenerator() self._components_in_dfs_order = []
self._graph_id_generator = GraphIDGenerator(id_number_offset)
self._input_operations = []
self._output_operations = []
# Maps original components to new copied components
self._added_components_mapping = {}
self._original_input_signals_indexes = {}
self._original_output_signals_indexes = {}
self._id_number_offset = id_number_offset
for operation in ops: # Setup input signals.
self._add_graph_component(operation) for input_index, sig in enumerate(input_signals):
assert sig not in self._added_components_mapping, "Duplicate input signals sent to SFG construcctor."
for input_signal in input_signals: new_input_op = self._add_component_copy_unconnected(Input())
self._add_graph_component(input_signal) new_sig = self._add_component_copy_unconnected(sig)
new_sig.set_source(new_input_op.output(0))
# TODO: Construct SFG based on what inputs that were given self._input_operations.append(new_input_op)
# TODO: Traverse the graph between the inputs/outputs and add to self._operations. self._original_input_signals_indexes[sig] = input_index
# TODO: Connect ports with signals with appropriate IDs.
def evaluate(self, *inputs) -> list: # Setup input operations, starting from indexes ater input signals.
return [] # TODO: Implement for input_index, input_op in enumerate(inputs, len(input_signals)):
assert input_op not in self._added_components_mapping, "Duplicate input operations sent to SFG constructor."
new_input_op = self._add_component_copy_unconnected(input_op)
def _add_graph_component(self, graph_component: GraphComponent) -> GraphID: for sig in input_op.output(0).signals:
"""Add the entered graph component to the SFG's dictionary of graph objects and assert sig not in self._added_components_mapping, "Duplicate input signals connected to input ports sent to SFG construcctor."
return a generated GraphID for it. new_sig = self._add_component_copy_unconnected(sig)
new_sig.set_source(new_input_op.output(0))
Keyword arguments: self._original_input_signals_indexes[sig] = input_index
graph_component: Graph component to add to the graph.
""" self._input_operations.append(new_input_op)
# Add to name dict
self._graph_components_by_name[graph_component.name].append(graph_component) # Setup output signals.
for output_ind, sig in enumerate(output_signals):
new_out = self._add_component_copy_unconnected(Output())
if sig in self._added_components_mapping:
# Signal already added when setting up inputs
new_sig = self._added_components_mapping[sig]
new_sig.set_destination(new_out.input(0))
else:
# New signal has to be created
new_sig = self._add_component_copy_unconnected(sig)
new_sig.set_destination(new_out.input(0))
self._output_operations.append(new_out)
self._original_output_signals_indexes[sig] = output_ind
# Setup output operations, starting from indexes after output signals.
for output_ind, output_op in enumerate(outputs, len(output_signals)):
assert output_op not in self._added_components_mapping, "Duplicate output operations sent to SFG constructor."
new_out = self._add_component_copy_unconnected(output_op)
for sig in output_op.input(0).signals:
if sig in self._added_components_mapping:
# Signal already added when setting up inputs
new_sig = self._added_components_mapping[sig]
new_sig.set_destination(new_out.input(0))
else:
# New signal has to be created
new_sig = self._add_component_copy_unconnected(sig)
new_sig.set_destination(new_out.input(0))
self._original_output_signals_indexes[sig] = output_ind
self._output_operations.append(new_out)
output_operations_set = set(self._output_operations)
# Add to ID dict # Search the graph inwards from each input signal.
graph_id: GraphID = self._graph_id_generator.get_next_id(graph_component.type_name) for sig, input_index in self._original_input_signals_indexes.items():
self._graph_components_by_id[graph_id] = graph_component # Check if already added destination.
return graph_id new_sig = self._added_components_mapping[sig]
if new_sig.destination is None:
if sig.destination is None:
raise ValueError(
f"Input signal #{input_index} is missing destination in SFG")
elif sig.destination.operation not in self._added_components_mapping:
self._copy_structure_from_operation_dfs(
sig.destination.operation)
else:
if new_sig.destination.operation in output_operations_set:
# Add directly connected input to output to dfs order list
self._components_in_dfs_order.extend([
new_sig.source.operation, new_sig, new_sig.destination.operation])
# Search the graph inwards from each output signal.
for sig, output_index in self._original_output_signals_indexes.items():
# Check if already added source.
new_sig = self._added_components_mapping[sig]
if new_sig.source is None:
if sig.source is None:
raise ValueError(
f"Output signal #{output_index} is missing source in SFG")
if sig.source.operation not in self._added_components_mapping:
self._copy_structure_from_operation_dfs(
sig.source.operation)
def __call__(self):
return self.deep_copy()
@property
def type_name(self) -> TypeName:
return "sfg"
def evaluate(self, *args):
if len(args) != self.input_count:
raise ValueError(
"Wrong number of inputs supplied to SFG for evaluation")
for arg, op in zip(args, self._input_operations):
op.value = arg
result = []
for op in self._output_operations:
result.append(self._evaluate_source(op.input(0).signals[0].source))
n = len(result)
return None if n == 0 else result[0] if n == 1 else result
def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
assert i >= 0 and i < self.output_count, "Output index out of range"
result = [None] * self.output_count
result[i] = self._evaluate_source(
self._output_operations[i].input(0).signals[0].source)
return result
def split(self) -> Iterable[Operation]:
return filter(lambda comp: isinstance(comp, Operation), self._components_by_id.values())
@property
def components(self) -> Iterable[GraphComponent]:
"""Get all components of this graph in the dfs-traversal order."""
return self._components_in_dfs_order
def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]: def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]:
"""Find a graph object based on the entered Graph ID and return it. If no graph """Find a graph object based on the entered Graph ID and return it. If no graph
...@@ -71,10 +201,7 @@ class SFG(AbstractOperation): ...@@ -71,10 +201,7 @@ class SFG(AbstractOperation):
Keyword arguments: Keyword arguments:
graph_id: Graph ID of the wanted object. graph_id: Graph ID of the wanted object.
""" """
if graph_id in self._graph_components_by_id: return self._components_by_id.get(graph_id, None)
return self._graph_components_by_id[graph_id]
return None
def find_by_name(self, name: Name) -> List[GraphComponent]: def find_by_name(self, name: Name) -> List[GraphComponent]:
"""Find all graph objects that have the entered name and return them """Find all graph objects that have the entered name and return them
...@@ -84,8 +211,146 @@ class SFG(AbstractOperation): ...@@ -84,8 +211,146 @@ class SFG(AbstractOperation):
Keyword arguments: Keyword arguments:
name: Name of the wanted object. name: Name of the wanted object.
""" """
return self._graph_components_by_name[name] return self._components_by_name.get(name, [])
@property def deep_copy(self) -> "SFG":
def type_name(self) -> TypeName: """Returns a deep copy of self."""
return "sfg" copy = SFG(inputs=self._input_operations, outputs=self._output_operations,
id_number_offset=self._id_number_offset, name=super().name)
return copy
def _add_component_copy_unconnected(self, original_comp: GraphComponent) -> GraphComponent:
assert original_comp not in self._added_components_mapping, "Tried to add duplicate SFG component"
new_comp = original_comp.copy_unconnected()
self._added_components_mapping[original_comp] = new_comp
self._components_by_id[self._graph_id_generator.next_id(
new_comp.type_name)] = new_comp
self._components_by_name[new_comp.name].append(new_comp)
return new_comp
def _copy_structure_from_operation_dfs(self, start_op: Operation):
op_stack = deque([start_op])
while op_stack:
original_op = op_stack.pop()
# Add or get the new copy of the operation..
new_op = None
if original_op not in self._added_components_mapping:
new_op = self._add_component_copy_unconnected(original_op)
self._components_in_dfs_order.append(new_op)
else:
new_op = self._added_components_mapping[original_op]
# Connect input ports to new signals
for original_input_port in original_op.inputs:
if original_input_port.signal_count < 1:
raise ValueError("Unconnected input port in SFG")
for original_signal in original_input_port.signals:
# Check if the signal is one of the SFG's input signals
if original_signal in self._original_input_signals_indexes:
# New signal already created during first step of constructor
new_signal = self._added_components_mapping[
original_signal]
new_signal.set_destination(
new_op.input(original_input_port.index))
self._components_in_dfs_order.extend(
[new_signal, new_signal.source.operation])
# Check if the signal has not been added before
elif original_signal not in self._added_components_mapping:
if original_signal.source is None:
raise ValueError(
"Dangling signal without source in SFG")
new_signal = self._add_component_copy_unconnected(
original_signal)
new_signal.set_destination(
new_op.input(original_input_port.index))
self._components_in_dfs_order.append(new_signal)
original_connected_op = original_signal.source.operation
# Check if connected Operation has been added before
if original_connected_op in self._added_components_mapping:
# Set source to the already added operations port
new_signal.set_source(
self._added_components_mapping[original_connected_op].output(
original_signal.source.index))
else:
# Create new operation, set signal source to it
new_connected_op = self._add_component_copy_unconnected(
original_connected_op)
new_signal.set_source(new_connected_op.output(
original_signal.source.index))
self._components_in_dfs_order.append(
new_connected_op)
# Add connected operation to queue of operations to visit
op_stack.append(original_connected_op)
# Connect output ports
for original_output_port in original_op.outputs:
for original_signal in original_output_port.signals:
# Check if the signal is one of the SFG's output signals.
if original_signal in self._original_output_signals_indexes:
# New signal already created during first step of constructor.
new_signal = self._added_components_mapping[
original_signal]
new_signal.set_source(
new_op.output(original_output_port.index))
self._components_in_dfs_order.extend(
[new_signal, new_signal.destination.operation])
# Check if signal has not been added before.
elif original_signal not in self._added_components_mapping:
if original_signal.source is None:
raise ValueError(
"Dangling signal without source in SFG")
new_signal = self._add_component_copy_unconnected(
original_signal)
new_signal.set_source(
new_op.output(original_output_port.index))
self._components_in_dfs_order.append(new_signal)
original_connected_op = original_signal.destination.operation
# Check if connected operation has been added.
if original_connected_op in self._added_components_mapping:
# Set destination to the already connected operations port
new_signal.set_destination(
self._added_components_mapping[original_connected_op].input(
original_signal.destination.index))
else:
# Create new operation, set destination to it.
new_connected_op = self._add_component_copy_unconnected(
original_connected_op)
new_signal.set_destination(new_connected_op.input(
original_signal.destination.index))
self._components_in_dfs_order.append(
new_connected_op)
# Add connected operation to the queue of operations to visist
op_stack.append(original_connected_op)
def _evaluate_source(self, src: OutputPort) -> Number:
input_values = []
for input_port in src.operation.inputs:
input_src = input_port.signals[0].source
input_values.append(self._evaluate_source(input_src))
return src.operation.evaluate_output(src.index, input_values)
...@@ -4,7 +4,7 @@ TODO: More info. ...@@ -4,7 +4,7 @@ TODO: More info.
""" """
from numbers import Number from numbers import Number
from typing import List from typing import List, Dict
class OperationState: class OperationState:
...@@ -25,11 +25,19 @@ class SimulationState: ...@@ -25,11 +25,19 @@ class SimulationState:
TODO: More info. TODO: More info.
""" """
# operation_states: Dict[OperationId, OperationState] operation_states: Dict[int, OperationState]
iteration: int iteration: int
def __init__(self): def __init__(self):
self.operation_states = {} op_state = OperationState()
self.operation_states = {1: op_state}
self.iteration = 0 self.iteration = 0
# TODO: More stuff. # @property
# #def iteration(self):
# return self.iteration
# @iteration.setter
# def iteration(self, new_iteration: int):
# self.iteration = new_iteration
#
# TODO: More stuff
"""@package docstring
B-ASIC Special Operations Module.
TODO: More info.
"""
from numbers import Number
from typing import Optional
from b_asic.operation import AbstractOperation
from b_asic.graph_component import Name, TypeName
from b_asic.port import SignalSourceProvider
class Input(AbstractOperation):
"""Input operation.
TODO: More info.
"""
def __init__(self, name: Name = ""):
super().__init__(input_count = 0, output_count = 1, name = name)
self.set_param("value", 0)
@property
def type_name(self) -> TypeName:
return "in"
def evaluate(self):
return self.param("value")
@property
def value(self) -> Number:
"""TODO: docstring"""
return self.param("value")
@value.setter
def value(self, value: Number):
"""TODO: docstring"""
self.set_param("value", value)
class Output(AbstractOperation):
"""Output operation.
TODO: More info.
"""
def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 1, output_count = 0, name = name, input_sources=[src0])
@property
def type_name(self) -> TypeName:
return "out"
def evaluate(self):
return None
\ No newline at end of file
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
namespace py = pybind11; namespace py = pybind11;
namespace asic { namespace asic {
int add(int a, int b) { int add(int a, int b) {
return a + b; return a + b;
} }
int sub(int a, int b) { int sub(int a, int b) {
return a - b; return a - b;
} }
} // namespace asic } // namespace asic
PYBIND11_MODULE(_b_asic, m) { PYBIND11_MODULE(_b_asic, m) {
m.doc() = "Better ASIC Toolbox Extension Module."; m.doc() = "Better ASIC Toolbox Extension Module.";
m.def("add", &asic::add, "A function which adds two numbers.", py::arg("a"), py::arg("b")); m.def("add", &asic::add, "A function which adds two numbers.", py::arg("a"), py::arg("b"));
m.def("sub", &asic::sub, "A function which subtracts two numbers.", py::arg("a"), py::arg("b")); m.def("sub", &asic::sub, "A function which subtracts two numbers.", py::arg("a"), py::arg("b"));
} }
\ No newline at end of file
...@@ -7,52 +7,24 @@ import pytest ...@@ -7,52 +7,24 @@ import pytest
def operation(): def operation():
return Constant(2) return Constant(2)
def create_operation(_type, dest_oper, index, **kwargs):
oper = _type(**kwargs)
oper_signal = Signal()
oper._output_ports[0].add_signal(oper_signal)
dest_oper._input_ports[index].add_signal(oper_signal)
return oper
@pytest.fixture @pytest.fixture
def operation_tree(): def operation_tree():
"""Return a addition operation connected with 2 constants. """Return a addition operation connected with 2 constants.
---C---+ ---C---+
---A +--A
---C---+ ---C---+
""" """
add_oper = Addition() return Addition(Constant(2), Constant(3))
create_operation(Constant, add_oper, 0, value=2)
create_operation(Constant, add_oper, 1, value=3)
return add_oper
@pytest.fixture @pytest.fixture
def large_operation_tree(): def large_operation_tree():
"""Return a constant operation connected with a large operation tree with 3 other constants and 3 additions. """Return an addition operation connected with a large operation tree with 2 other additions and 4 constants.
---C---+ ---C---+
---A---+ +--A---+
---C---+ | ---C---+ |
+---A +---A
---C---+ | ---C---+ |
---A---+ +--A---+
---C---+ ---C---+
""" """
add_oper = Addition() return Addition(Addition(Constant(2), Constant(3)), Addition(Constant(4), Constant(5)))
add_oper_2 = Addition()
const_oper = create_operation(Constant, add_oper, 0, value=2)
create_operation(Constant, add_oper, 1, value=3)
create_operation(Constant, add_oper_2, 0, value=4)
create_operation(Constant, add_oper_2, 1, value=5)
add_oper_3 = Addition()
add_oper_signal = Signal(add_oper.output(0), add_oper_3.output(0))
add_oper._output_ports[0].add_signal(add_oper_signal)
add_oper_3._input_ports[0].add_signal(add_oper_signal)
add_oper_2_signal = Signal(add_oper_2.output(0), add_oper_3.output(0))
add_oper_2._output_ports[0].add_signal(add_oper_2_signal)
add_oper_3._input_ports[1].add_signal(add_oper_2_signal)
return const_oper
...@@ -3,8 +3,8 @@ from b_asic.port import InputPort, OutputPort ...@@ -3,8 +3,8 @@ from b_asic.port import InputPort, OutputPort
@pytest.fixture @pytest.fixture
def input_port(): def input_port():
return InputPort(0, None) return InputPort(None, 0)
@pytest.fixture @pytest.fixture
def output_port(): def output_port():
return OutputPort(0, None) return OutputPort(None, 0)
...@@ -9,4 +9,4 @@ def signal(): ...@@ -9,4 +9,4 @@ def signal():
@pytest.fixture @pytest.fixture
def signals(): def signals():
"""Return 3 signals with no connections.""" """Return 3 signals with no connections."""
return [Signal() for _ in range(0,3)] return [Signal() for _ in range(0, 3)]
...@@ -2,226 +2,313 @@ ...@@ -2,226 +2,313 @@
B-ASIC test suite for the core operations. B-ASIC test suite for the core operations.
""" """
from b_asic.core_operations import Constant, Addition, Subtraction, Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, ConstantDivision from b_asic.core_operations import Constant, Addition, Subtraction, \
Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, \
Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, \
ConstantDivision, Butterfly
# Constant tests. # Constant tests.
def test_constant(): def test_constant():
constant_operation = Constant(3) constant_operation = Constant(3)
assert constant_operation.evaluate() == 3 assert constant_operation.evaluate() == 3
def test_constant_negative(): def test_constant_negative():
constant_operation = Constant(-3) constant_operation = Constant(-3)
assert constant_operation.evaluate() == -3 assert constant_operation.evaluate() == -3
def test_constant_complex(): def test_constant_complex():
constant_operation = Constant(3+4j) constant_operation = Constant(3+4j)
assert constant_operation.evaluate() == 3+4j assert constant_operation.evaluate() == 3+4j
# Addition tests. # Addition tests.
def test_addition(): def test_addition():
test_operation = Addition() test_operation = Addition()
constant_operation = Constant(3) constant_operation = Constant(3)
constant_operation_2 = Constant(5) constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 8 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 8
def test_addition_negative(): def test_addition_negative():
test_operation = Addition() test_operation = Addition()
constant_operation = Constant(-3) constant_operation = Constant(-3)
constant_operation_2 = Constant(-5) constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -8 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -8
def test_addition_complex(): def test_addition_complex():
test_operation = Addition() test_operation = Addition()
constant_operation = Constant((3+5j)) constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j)) constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j) assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j)
# Subtraction tests. # Subtraction tests.
def test_subtraction(): def test_subtraction():
test_operation = Subtraction() test_operation = Subtraction()
constant_operation = Constant(5) constant_operation = Constant(5)
constant_operation_2 = Constant(3) constant_operation_2 = Constant(3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 2 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 2
def test_subtraction_negative(): def test_subtraction_negative():
test_operation = Subtraction() test_operation = Subtraction()
constant_operation = Constant(-5) constant_operation = Constant(-5)
constant_operation_2 = Constant(-3) constant_operation_2 = Constant(-3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -2 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -2
def test_subtraction_complex(): def test_subtraction_complex():
test_operation = Subtraction() test_operation = Subtraction()
constant_operation = Constant((3+5j)) constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j)) constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j) assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j)
# Multiplication tests. # Multiplication tests.
def test_multiplication(): def test_multiplication():
test_operation = Multiplication() test_operation = Multiplication()
constant_operation = Constant(5) constant_operation = Constant(5)
constant_operation_2 = Constant(3) constant_operation_2 = Constant(3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 15
def test_multiplication_negative(): def test_multiplication_negative():
test_operation = Multiplication() test_operation = Multiplication()
constant_operation = Constant(-5) constant_operation = Constant(-5)
constant_operation_2 = Constant(-3) constant_operation_2 = Constant(-3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 15
def test_multiplication_complex(): def test_multiplication_complex():
test_operation = Multiplication() test_operation = Multiplication()
constant_operation = Constant((3+5j)) constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j)) constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j) assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j)
# Division tests. # Division tests.
def test_division(): def test_division():
test_operation = Division() test_operation = Division()
constant_operation = Constant(30) constant_operation = Constant(30)
constant_operation_2 = Constant(5) constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 6
def test_division_negative(): def test_division_negative():
test_operation = Division() test_operation = Division()
constant_operation = Constant(-30) constant_operation = Constant(-30)
constant_operation_2 = Constant(-5) constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 6
def test_division_complex(): def test_division_complex():
test_operation = Division() test_operation = Division()
constant_operation = Constant((60+40j)) constant_operation = Constant((60+40j))
constant_operation_2 = Constant((10+20j)) constant_operation_2 = Constant((10+20j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j) assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j)
# SquareRoot tests. # SquareRoot tests.
def test_squareroot(): def test_squareroot():
test_operation = SquareRoot() test_operation = SquareRoot()
constant_operation = Constant(36) constant_operation = Constant(36)
assert test_operation.evaluate(constant_operation.evaluate()) == 6 assert test_operation.evaluate(constant_operation.evaluate()) == 6
def test_squareroot_negative(): def test_squareroot_negative():
test_operation = SquareRoot() test_operation = SquareRoot()
constant_operation = Constant(-36) constant_operation = Constant(-36)
assert test_operation.evaluate(constant_operation.evaluate()) == 6j assert test_operation.evaluate(constant_operation.evaluate()) == 6j
def test_squareroot_complex(): def test_squareroot_complex():
test_operation = SquareRoot() test_operation = SquareRoot()
constant_operation = Constant((48+64j)) constant_operation = Constant((48+64j))
assert test_operation.evaluate(constant_operation.evaluate()) == (8+4j) assert test_operation.evaluate(constant_operation.evaluate()) == (8+4j)
# ComplexConjugate tests. # ComplexConjugate tests.
def test_complexconjugate(): def test_complexconjugate():
test_operation = ComplexConjugate() test_operation = ComplexConjugate()
constant_operation = Constant(3+4j) constant_operation = Constant(3+4j)
assert test_operation.evaluate(constant_operation.evaluate()) == (3-4j) assert test_operation.evaluate(constant_operation.evaluate()) == (3-4j)
def test_test_complexconjugate_negative(): def test_test_complexconjugate_negative():
test_operation = ComplexConjugate() test_operation = ComplexConjugate()
constant_operation = Constant(-3-4j) constant_operation = Constant(-3-4j)
assert test_operation.evaluate(constant_operation.evaluate()) == (-3+4j) assert test_operation.evaluate(constant_operation.evaluate()) == (-3+4j)
# Max tests. # Max tests.
def test_max(): def test_max():
test_operation = Max() test_operation = Max()
constant_operation = Constant(30) constant_operation = Constant(30)
constant_operation_2 = Constant(5) constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 30 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 30
def test_max_negative(): def test_max_negative():
test_operation = Max() test_operation = Max()
constant_operation = Constant(-30) constant_operation = Constant(-30)
constant_operation_2 = Constant(-5) constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -5 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -5
# Min tests. # Min tests.
def test_min(): def test_min():
test_operation = Min() test_operation = Min()
constant_operation = Constant(30) constant_operation = Constant(30)
constant_operation_2 = Constant(5) constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 5 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 5
def test_min_negative(): def test_min_negative():
test_operation = Min() test_operation = Min()
constant_operation = Constant(-30) constant_operation = Constant(-30)
constant_operation_2 = Constant(-5) constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -30 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -30
# Absolute tests. # Absolute tests.
def test_absolute(): def test_absolute():
test_operation = Absolute() test_operation = Absolute()
constant_operation = Constant(30) constant_operation = Constant(30)
assert test_operation.evaluate(constant_operation.evaluate()) == 30 assert test_operation.evaluate(constant_operation.evaluate()) == 30
def test_absolute_negative(): def test_absolute_negative():
test_operation = Absolute() test_operation = Absolute()
constant_operation = Constant(-5) constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == 5 assert test_operation.evaluate(constant_operation.evaluate()) == 5
def test_absolute_complex(): def test_absolute_complex():
test_operation = Absolute() test_operation = Absolute()
constant_operation = Constant((3+4j)) constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == 5.0 assert test_operation.evaluate(constant_operation.evaluate()) == 5.0
# ConstantMultiplication tests. # ConstantMultiplication tests.
def test_constantmultiplication(): def test_constantmultiplication():
test_operation = ConstantMultiplication(5) test_operation = ConstantMultiplication(5)
constant_operation = Constant(20) constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 100 assert test_operation.evaluate(constant_operation.evaluate()) == 100
def test_constantmultiplication_negative(): def test_constantmultiplication_negative():
test_operation = ConstantMultiplication(5) test_operation = ConstantMultiplication(5)
constant_operation = Constant(-5) constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -25 assert test_operation.evaluate(constant_operation.evaluate()) == -25
def test_constantmultiplication_complex(): def test_constantmultiplication_complex():
test_operation = ConstantMultiplication(3+2j) test_operation = ConstantMultiplication(3+2j)
constant_operation = Constant((3+4j)) constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (1+18j) assert test_operation.evaluate(constant_operation.evaluate()) == (1+18j)
# ConstantAddition tests. # ConstantAddition tests.
def test_constantaddition(): def test_constantaddition():
test_operation = ConstantAddition(5) test_operation = ConstantAddition(5)
constant_operation = Constant(20) constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 25 assert test_operation.evaluate(constant_operation.evaluate()) == 25
def test_constantaddition_negative(): def test_constantaddition_negative():
test_operation = ConstantAddition(4) test_operation = ConstantAddition(4)
constant_operation = Constant(-5) constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -1 assert test_operation.evaluate(constant_operation.evaluate()) == -1
def test_constantaddition_complex(): def test_constantaddition_complex():
test_operation = ConstantAddition(3+2j) test_operation = ConstantAddition(3+2j)
constant_operation = Constant((3+4j)) constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (6+6j) assert test_operation.evaluate(constant_operation.evaluate()) == (6+6j)
# ConstantSubtraction tests. # ConstantSubtraction tests.
def test_constantsubtraction(): def test_constantsubtraction():
test_operation = ConstantSubtraction(5) test_operation = ConstantSubtraction(5)
constant_operation = Constant(20) constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 15 assert test_operation.evaluate(constant_operation.evaluate()) == 15
def test_constantsubtraction_negative(): def test_constantsubtraction_negative():
test_operation = ConstantSubtraction(4) test_operation = ConstantSubtraction(4)
constant_operation = Constant(-5) constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -9 assert test_operation.evaluate(constant_operation.evaluate()) == -9
def test_constantsubtraction_complex(): def test_constantsubtraction_complex():
test_operation = ConstantSubtraction(4+6j) test_operation = ConstantSubtraction(4+6j)
constant_operation = Constant((3+4j)) constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (-1-2j) assert test_operation.evaluate(constant_operation.evaluate()) == (-1-2j)
# ConstantDivision tests. # ConstantDivision tests.
def test_constantdivision(): def test_constantdivision():
test_operation = ConstantDivision(5) test_operation = ConstantDivision(5)
constant_operation = Constant(20) constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 4 assert test_operation.evaluate(constant_operation.evaluate()) == 4
def test_constantdivision_negative(): def test_constantdivision_negative():
test_operation = ConstantDivision(4) test_operation = ConstantDivision(4)
constant_operation = Constant(-20) constant_operation = Constant(-20)
assert test_operation.evaluate(constant_operation.evaluate()) == -5 assert test_operation.evaluate(constant_operation.evaluate()) == -5
def test_constantdivision_complex(): def test_constantdivision_complex():
test_operation = ConstantDivision(2+2j) test_operation = ConstantDivision(2+2j)
constant_operation = Constant((10+10j)) constant_operation = Constant((10+10j))
assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j) assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j)
def test_butterfly():
test_operation = Butterfly()
assert list(test_operation.evaluate(2, 3)) == [5, -1]
def test_butterfly_negative():
test_operation = Butterfly()
assert list(test_operation.evaluate(-2, -3)) == [-5, 1]
def test_buttefly_complex():
test_operation = Butterfly()
assert list(test_operation.evaluate(2+1j, 3-2j)) == [5-1j, -1+3j]
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
B-ASIC test suite for graph id generator. B-ASIC test suite for graph id generator.
""" """
from b_asic.graph_id import GraphIDGenerator, GraphID from b_asic.signal_flow_graph import GraphIDGenerator, GraphID
import pytest import pytest
@pytest.fixture @pytest.fixture
...@@ -12,17 +12,17 @@ def graph_id_generator(): ...@@ -12,17 +12,17 @@ def graph_id_generator():
class TestGetNextId: class TestGetNextId:
def test_empty_string_generator(self, graph_id_generator): def test_empty_string_generator(self, graph_id_generator):
"""Test the graph id generator for an empty string type.""" """Test the graph id generator for an empty string type."""
assert graph_id_generator.get_next_id("") == "1" assert graph_id_generator.next_id("") == "1"
assert graph_id_generator.get_next_id("") == "2" assert graph_id_generator.next_id("") == "2"
def test_normal_string_generator(self, graph_id_generator): def test_normal_string_generator(self, graph_id_generator):
""""Test the graph id generator for a normal string type.""" """"Test the graph id generator for a normal string type."""
assert graph_id_generator.get_next_id("add") == "add1" assert graph_id_generator.next_id("add") == "add1"
assert graph_id_generator.get_next_id("add") == "add2" assert graph_id_generator.next_id("add") == "add2"
def test_different_strings_generator(self, graph_id_generator): def test_different_strings_generator(self, graph_id_generator):
"""Test the graph id generator for different strings.""" """Test the graph id generator for different strings."""
assert graph_id_generator.get_next_id("sub") == "sub1" assert graph_id_generator.next_id("sub") == "sub1"
assert graph_id_generator.get_next_id("mul") == "mul1" assert graph_id_generator.next_id("mul") == "mul1"
assert graph_id_generator.get_next_id("sub") == "sub2" assert graph_id_generator.next_id("sub") == "sub2"
assert graph_id_generator.get_next_id("mul") == "mul2" assert graph_id_generator.next_id("mul") == "mul2"
...@@ -9,42 +9,37 @@ from b_asic import Signal ...@@ -9,42 +9,37 @@ from b_asic import Signal
@pytest.fixture @pytest.fixture
def inp_port(): def inp_port():
return InputPort(0, None) return InputPort(None, 0)
@pytest.fixture @pytest.fixture
def out_port(): def out_port():
return OutputPort(0, None) return OutputPort(None, 0)
@pytest.fixture @pytest.fixture
def out_port2(): def out_port2():
return OutputPort(1, None) return OutputPort(None, 1)
@pytest.fixture @pytest.fixture
def dangling_sig(): def dangling_sig():
return Signal() return Signal()
@pytest.fixture @pytest.fixture
def s_w_source(): def s_w_source(out_port):
out_port = OutputPort(0, None)
return Signal(source=out_port) return Signal(source=out_port)
@pytest.fixture @pytest.fixture
def sig_with_dest(): def sig_with_dest(inp_port):
inp_port = InputPort(0, None) return Signal(destination=inp_port)
return Signal(destination=out_port)
@pytest.fixture @pytest.fixture
def connected_sig(): def connected_sig(inp_port, out_port):
out_port = OutputPort(0, None)
inp_port = InputPort(0, None)
return Signal(source=out_port, destination=inp_port) return Signal(source=out_port, destination=inp_port)
def test_connect_then_disconnect(inp_port, out_port): def test_connect_then_disconnect(inp_port, out_port):
"""Test connect unused port to port.""" """Test connect unused port to port."""
s1 = inp_port.connect(out_port) s1 = inp_port.connect(out_port)
assert inp_port.connected_ports == [out_port] assert inp_port.connected_source == out_port
assert out_port.connected_ports == [inp_port]
assert inp_port.signals == [s1] assert inp_port.signals == [s1]
assert out_port.signals == [s1] assert out_port.signals == [s1]
assert s1.source is out_port assert s1.source is out_port
...@@ -52,8 +47,7 @@ def test_connect_then_disconnect(inp_port, out_port): ...@@ -52,8 +47,7 @@ def test_connect_then_disconnect(inp_port, out_port):
inp_port.remove_signal(s1) inp_port.remove_signal(s1)
assert inp_port.connected_ports == [] assert inp_port.connected_source is None
assert out_port.connected_ports == []
assert inp_port.signals == [] assert inp_port.signals == []
assert out_port.signals == [s1] assert out_port.signals == [s1]
assert s1.source is out_port assert s1.source is out_port
...@@ -62,34 +56,46 @@ def test_connect_then_disconnect(inp_port, out_port): ...@@ -62,34 +56,46 @@ def test_connect_then_disconnect(inp_port, out_port):
def test_connect_used_port_to_new_port(inp_port, out_port, out_port2): def test_connect_used_port_to_new_port(inp_port, out_port, out_port2):
"""Does connecting multiple ports to an inputport throw error?""" """Does connecting multiple ports to an inputport throw error?"""
inp_port.connect(out_port) inp_port.connect(out_port)
with pytest.raises(AssertionError): with pytest.raises(Exception):
inp_port.connect(out_port2) inp_port.connect(out_port2)
def test_add_signal_then_disconnect(inp_port, s_w_source): def test_add_signal_then_disconnect(inp_port, s_w_source):
"""Can signal be connected then disconnected properly?""" """Can signal be connected then disconnected properly?"""
inp_port.add_signal(s_w_source) inp_port.add_signal(s_w_source)
assert inp_port.connected_ports == [s_w_source.source] assert inp_port.connected_source == s_w_source.source
assert s_w_source.source.connected_ports == [inp_port]
assert inp_port.signals == [s_w_source] assert inp_port.signals == [s_w_source]
assert s_w_source.source.signals == [s_w_source] assert s_w_source.source.signals == [s_w_source]
assert s_w_source.destination is inp_port assert s_w_source.destination is inp_port
inp_port.remove_signal(s_w_source) inp_port.remove_signal(s_w_source)
assert inp_port.connected_ports == [] assert inp_port.connected_source is None
assert s_w_source.source.connected_ports == []
assert inp_port.signals == [] assert inp_port.signals == []
assert s_w_source.source.signals == [s_w_source] assert s_w_source.source.signals == [s_w_source]
assert s_w_source.destination is None assert s_w_source.destination is None
def test_connect_then_disconnect(inp_port, out_port): def test_set_value_length_pos_int(inp_port):
"""Can port be connected and then disconnected properly?""" inp_port.value_length = 10
inp_port.connect(out_port) assert inp_port.value_length == 10
def test_set_value_length_zero(inp_port):
inp_port.value_length = 0
assert inp_port.value_length == 0
def test_set_value_length_neg_int(inp_port):
with pytest.raises(Exception):
inp_port.value_length = -10
def test_set_value_length_complex(inp_port):
with pytest.raises(Exception):
inp_port.value_length = (2+4j)
inp_port.disconnect(out_port) def test_set_value_length_float(inp_port):
with pytest.raises(Exception):
inp_port.value_length = 3.2
print("outport signals:", out_port.signals, "count:", out_port.signal_count()) def test_set_value_length_pos_then_none(inp_port):
assert inp_port.signal_count() == 1 inp_port.value_length = 10
assert len(inp_port.connected_ports) == 0 inp_port.value_length = None
assert out_port.signal_count() == 0 assert inp_port.value_length is None
from b_asic.core_operations import Constant, Addition from b_asic.core_operations import Constant, Addition, ConstantAddition, Butterfly
from b_asic.signal import Signal from b_asic.signal import Signal
from b_asic.port import InputPort, OutputPort from b_asic.port import InputPort, OutputPort
import pytest import pytest
class TestTraverse: class TestTraverse:
def test_traverse_single_tree(self, operation): def test_traverse_single_tree(self, operation):
"""Traverse a tree consisting of one operation.""" """Traverse a tree consisting of one operation."""
...@@ -20,12 +21,11 @@ class TestTraverse: ...@@ -20,12 +21,11 @@ class TestTraverse:
def test_traverse_type(self, large_operation_tree): def test_traverse_type(self, large_operation_tree):
traverse = list(large_operation_tree.traverse()) traverse = list(large_operation_tree.traverse())
assert len(list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3 assert len(
assert len(list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4 list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3
assert len(
list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4
def test_traverse_loop(self, operation_tree): def test_traverse_loop(self, operation_tree):
add_oper_signal = Signal() # TODO: Construct a graph that contains a loop and make sure you can traverse it properly.
operation_tree._output_ports[0].add_signal(add_oper_signal) assert True
operation_tree._input_ports[0].remove_signal(add_oper_signal)
operation_tree._input_ports[0].add_signal(add_oper_signal)
assert len(list(operation_tree.traverse())) == 2