"""@package docstring
B-ASIC Operation Module.
TODO: More info.
"""
import collections.abc
from abc import abstractmethod
from collections import deque
from copy import deepcopy
from numbers import Number
from typing import List, Sequence, Iterable, Dict, Optional, Any, Set, Generator, Union

from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
from b_asic.port import SignalSourceProvider, InputPort, OutputPort
class Operation(GraphComponent, SignalSourceProvider):
    """Operation interface.

    Declares the port, parameter, evaluation, splitting and traversal API of an
    operation. Port lists, port counts, parameters and neighbors are exposed as
    read-only properties; individual ports are fetched with input(i)/output(i).
    TODO: More info.
    """

    @abstractmethod
    def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]":
        """Overloads the addition operator to make it return a new Addition operation
        object that is connected to the self and other objects. If other is a number then
        returns a ConstantAddition operation object instead.
        """
        raise NotImplementedError

    @abstractmethod
    def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]":
        """Overloads the subtraction operator to make it return a new Subtraction operation
        object that is connected to the self and other objects. If other is a number then
        returns a ConstantSubtraction operation object instead.
        """
        raise NotImplementedError

    @abstractmethod
    def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        """Overloads the multiplication operator to make it return a new Multiplication operation
        object that is connected to the self and other objects. If other is a number then
        returns a ConstantMultiplication operation object instead.
        """
        raise NotImplementedError

    @abstractmethod
    def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]":
        """Overloads the division operator to make it return a new Division operation
        object that is connected to the self and other objects. If other is a number then
        returns a ConstantDivision operation object instead.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def inputs(self) -> List[InputPort]:
        """Get a list of all input ports."""
        raise NotImplementedError

    @property
    @abstractmethod
    def outputs(self) -> List[OutputPort]:
        """Get a list of all output ports."""
        raise NotImplementedError

    # The counts and params are properties: the rest of the module reads them as
    # attributes (e.g. `self.output_count != 1`). As plain methods those
    # comparisons would silently compare a bound method with an int.
    @property
    @abstractmethod
    def input_count(self) -> int:
        """Get the number of input ports."""
        raise NotImplementedError

    @property
    @abstractmethod
    def output_count(self) -> int:
        """Get the number of output ports."""
        raise NotImplementedError

    @abstractmethod
    def input(self, i: int) -> InputPort:
        """Get the input port at index i."""
        raise NotImplementedError

    @abstractmethod
    def output(self, i: int) -> OutputPort:
        """Get the output port at index i."""
        raise NotImplementedError

    @property
    @abstractmethod
    def params(self) -> Dict[str, Optional[Any]]:
        """Get a dictionary of all parameter values."""
        raise NotImplementedError

    @abstractmethod
    def param(self, name: str) -> Optional[Any]:
        """Get the value of a parameter.
        Returns None if the parameter is not defined.
        """
        raise NotImplementedError

    @abstractmethod
    def set_param(self, name: str, value: Any) -> None:
        """Set the value of a parameter.
        Adds the parameter if it is not already defined.
        """
        raise NotImplementedError

    @abstractmethod
    def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
        """Evaluate the output at index i of this operation with the given input values.
        The returned sequence contains results corresponding to each output of this operation,
        where a value of None means it was not evaluated.
        The value at index i is guaranteed to have been evaluated, while the others may or may not
        have been evaluated depending on what is the most efficient.
        For example, Butterfly().evaluate_output(1, [5, 4]) may result in either (9, 1) or (None, 1).
        """
        raise NotImplementedError

    @abstractmethod
    def split(self) -> Iterable["Operation"]:
        """Split the operation into multiple operations.
        If splitting is not possible, this may return a list containing only the operation itself.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def neighbors(self) -> Iterable["Operation"]:
        """Return all operations that are connected by signals to this operation.
        If no neighbors are found, this returns an empty list.
        """
        raise NotImplementedError

    @abstractmethod
    def traverse(self) -> Generator["Operation", None, None]:
        """Get a generator that recursively iterates through all operations that are connected by signals to this operation,
        as well as the ones that they are connected to.
        """
        raise NotImplementedError
class AbstractOperation(Operation, AbstractGraphComponent):
    """Generic abstract operation class which most implementations will derive from.

    Stores the input ports, output ports and parameter dictionary of an
    operation and implements the generic parts of the Operation interface on
    top of them. Subclasses supply the actual computation by overriding
    evaluate().
    TODO: More info.
    """

    _input_ports: List[InputPort]
    _output_ports: List[OutputPort]
    _parameters: Dict[str, Optional[Any]]

    def __init__(self, input_count: int, output_count: int, name: Name = "",
                 input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
        """Allocate the ports of the operation and optionally connect its inputs.

        Args:
            input_count: Number of input ports to allocate.
            output_count: Number of output ports to allocate.
            name: Name of the operation, forwarded to the base class.
            input_sources: Optional sequence with exactly one entry per input
                port. The ``source`` of each non-None entry is connected to the
                input port with the same index; None leaves that port unconnected.

        Raises:
            ValueError: If input_sources is given and its length differs from
                input_count.
        """
        super().__init__(name)
        self._input_ports = []
        self._output_ports = []
        self._parameters = {}

        # Allocate input ports.
        for i in range(input_count):
            self._input_ports.append(InputPort(self, i))

        # Allocate output ports.
        for i in range(output_count):
            self._output_ports.append(OutputPort(self, i))

        # Connect given input sources, if any.
        if input_sources is not None:
            source_count = len(input_sources)
            if source_count != input_count:
                # "got" rather than "only got": too many sources is an error as well.
                raise ValueError(
                    f"Operation expected {input_count} input sources but got {source_count}")
            for port, src in zip(self._input_ports, input_sources):
                if src is not None:
                    port.connect(src.source)

    def evaluate(self, *inputs) -> Any:  # pylint: disable=arguments-differ
        """Evaluate the operation and generate a list of output values given a
        list of input values.

        Subclasses must override this; the base implementation raises
        NotImplementedError. A single-output operation may return a single
        number, otherwise a sequence with one value per output (see
        evaluate_output).
        """
        raise NotImplementedError

    def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Addition, ConstantAddition

        if isinstance(src, Number):
            return ConstantAddition(src, self)
        return Addition(self, src)

    def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Subtraction, ConstantSubtraction

        if isinstance(src, Number):
            return ConstantSubtraction(src, self)
        return Subtraction(self, src)

    def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Multiplication, ConstantMultiplication

        if isinstance(src, Number):
            return ConstantMultiplication(src, self)
        return Multiplication(self, src)

    def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Division, ConstantDivision

        if isinstance(src, Number):
            return ConstantDivision(src, self)
        return Division(self, src)

    @property
    def inputs(self) -> List[InputPort]:
        # Return a copy (as params does) so callers cannot add or remove ports.
        return self._input_ports.copy()

    @property
    def outputs(self) -> List[OutputPort]:
        # Return a copy (as params does) so callers cannot add or remove ports.
        return self._output_ports.copy()

    @property
    def input_count(self) -> int:
        return len(self._input_ports)

    @property
    def output_count(self) -> int:
        return len(self._output_ports)

    def input(self, i: int) -> InputPort:
        return self._input_ports[i]

    def output(self, i: int) -> OutputPort:
        return self._output_ports[i]

    @property
    def params(self) -> Dict[str, Optional[Any]]:
        return self._parameters.copy()

    def param(self, name: str) -> Optional[Any]:
        return self._parameters.get(name)

    def set_param(self, name: str, value: Any) -> None:
        self._parameters[name] = value

    def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
        """Evaluate all outputs via evaluate() and normalize the result to a sequence.

        Raises:
            RuntimeError: If evaluate() returns the wrong number of outputs or
                something that is neither a sequence nor a number.
        """
        result = self.evaluate(*input_values)
        # collections.abc.Sequence: the bare collections.Sequence alias was removed in Python 3.10.
        if isinstance(result, collections.abc.Sequence):
            if len(result) != self.output_count:
                raise RuntimeError("Operation evaluated to incorrect number of outputs")
            return result
        if isinstance(result, Number):
            if self.output_count != 1:
                raise RuntimeError("Operation evaluated to incorrect number of outputs")
            return [result]
        raise RuntimeError("Operation evaluated to invalid type")

    def split(self) -> Iterable[Operation]:
        """Split by evaluating this operation symbolically on fresh Input operations.

        Returns the operation(s) evaluate() builds from the inputs, or [self]
        when evaluate() cannot be applied to operations or does not return
        operations.
        """
        # Import here to avoid circular imports.
        from b_asic.special_operations import Input

        try:
            # One distinct Input per port, passed as separate arguments
            # (evaluate takes *inputs, exactly as evaluate_output calls it).
            result = self.evaluate(*[Input() for _ in range(self.input_count)])
        except (TypeError, ValueError):
            return [self]
        if isinstance(result, collections.abc.Sequence) and all(isinstance(e, Operation) for e in result):
            return result
        if isinstance(result, Operation):
            return [result]
        # Not splittable: honor the interface contract instead of returning None.
        return [self]

    @property
    def neighbors(self) -> Iterable[Operation]:
        """Operations on the other end of every signal attached to any port (may contain duplicates)."""
        neighbors = []
        for port in self._input_ports:
            for signal in port.signals:
                neighbors.append(signal.source.operation)
        for port in self._output_ports:
            for signal in port.signals:
                neighbors.append(signal.destination.operation)
        return neighbors

    def traverse(self) -> Generator[Operation, None, None]:
        """Yield this operation and every operation reachable from it, breadth first, each once."""
        visited = {self}
        queue = deque([self])
        while queue:
            operation = queue.popleft()
            yield operation
            for n_operation in operation.neighbors:
                if n_operation not in visited:
                    visited.add(n_operation)
                    queue.append(n_operation)

    @property
    def source(self) -> OutputPort:
        """The single output port, allowing this operation to be used as a signal source.

        Raises:
            TypeError: If the operation does not have exactly one output.
        """
        if self.output_count != 1:
            diff = "more" if self.output_count > 1 else "less"
            raise TypeError(
                f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output")
        return self.output(0)
def copy_unconnected(self) -> GraphComponent:
new_comp: AbstractOperation = super().copy_unconnected()
for name, value in self.params.items():

Adam Jakobsson
committed
new_comp.set_param(name, deepcopy(
value)) # pylint: disable=no-member