"""@package docstring
B-ASIC Operation Module.
TODO: More info.
"""

import collections
import collections.abc
from abc import abstractmethod
from collections import deque
from copy import deepcopy
from numbers import Number
from typing import List, Sequence, Iterable, Dict, Optional, Any, Set, Generator, Union

from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
from b_asic.port import SignalSourceProvider, InputPort, OutputPort


class Operation(GraphComponent, SignalSourceProvider):
    """Operation interface.
    TODO: More info.
    """

    @abstractmethod
    def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]":
        """Overload the addition operator.

        Return a new Addition operation connected to self and src.
        If src is a number, return a ConstantAddition operation instead.
        """
        raise NotImplementedError

    @abstractmethod
    def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]":
        """Overload the subtraction operator.

        Return a new Subtraction operation connected to self and src.
        If src is a number, return a ConstantSubtraction operation instead.
        """
        raise NotImplementedError

    @abstractmethod
    def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        """Overload the multiplication operator.

        Return a new Multiplication operation connected to self and src.
        If src is a number, return a ConstantMultiplication operation instead.
        """
        raise NotImplementedError

    @abstractmethod
    def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]":
        """Overload the division operator.

        Return a new Division operation connected to self and src.
        If src is a number, return a ConstantDivision operation instead.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def inputs(self) -> List[InputPort]:
        """Get a list of all input ports."""
        raise NotImplementedError

    @property
    @abstractmethod
    def outputs(self) -> List[OutputPort]:
        """Get a list of all output ports."""
        raise NotImplementedError

    @property
    @abstractmethod
    def input_count(self) -> int:
        """Get the number of input ports."""
        raise NotImplementedError

    @property
    @abstractmethod
    def output_count(self) -> int:
        """Get the number of output ports."""
        raise NotImplementedError

    @abstractmethod
    def input(self, i: int) -> InputPort:
        """Get the input port at index i."""
        raise NotImplementedError

    @abstractmethod
    def output(self, i: int) -> OutputPort:
        """Get the output port at index i."""
        raise NotImplementedError

    # Declared as a property to match AbstractOperation.params, which is
    # accessed as an attribute (e.g. `self.params.items()` in copy_unconnected).
    @property
    @abstractmethod
    def params(self) -> Dict[str, Optional[Any]]:
        """Get a dictionary of all parameter values."""
        raise NotImplementedError

    @abstractmethod
    def param(self, name: str) -> Optional[Any]:
        """Get the value of a parameter.

        Returns None if the parameter is not defined.
        """
        raise NotImplementedError

    @abstractmethod
    def set_param(self, name: str, value: Any) -> None:
        """Set the value of a parameter.

        Adds the parameter if it is not already defined.
        """
        raise NotImplementedError

    @abstractmethod
    def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
        """Evaluate the output at index i of this operation with the given input values.

        The returned sequence contains results corresponding to each output of this
        operation, where a value of None means it was not evaluated. The value at
        index i is guaranteed to have been evaluated, while the others may or may not
        have been evaluated depending on what is the most efficient.
        For example, Butterfly().evaluate_output(1, [5, 4]) may result in either
        (9, 1) or (None, 1).
        """
        raise NotImplementedError

    @abstractmethod
    def split(self) -> Iterable["Operation"]:
        """Split the operation into multiple operations.

        If splitting is not possible, this may return a list containing only the
        operation itself.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def neighbors(self) -> Iterable["Operation"]:
        """Return all operations that are connected by signals to this operation.

        If no neighbors are found, this returns an empty list.
        """
        raise NotImplementedError

    @abstractmethod
    def traverse(self) -> Generator["Operation", None, None]:
        """Get a generator that recursively iterates through all operations.

        The iteration covers the operations connected by signals to this operation,
        as well as the ones that they are connected to, and so on.
        """
        raise NotImplementedError


class AbstractOperation(Operation, AbstractGraphComponent):
    """Generic abstract operation class which most implementations will derive from.
    TODO: More info.
    """

    _input_ports: List[InputPort]
    _output_ports: List[OutputPort]
    _parameters: Dict[str, Optional[Any]]

    def __init__(self, input_count: int, output_count: int, name: Name = "",
                 input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
        """Create the operation and allocate its ports.

        Args:
            input_count: Number of input ports to allocate.
            output_count: Number of output ports to allocate.
            name: Name of the graph component.
            input_sources: Optional sources to connect to the input ports, one per
                input port. A None entry leaves the corresponding port unconnected.

        Raises:
            ValueError: If input_sources is given and its length differs from
                input_count.
        """
        super().__init__(name)
        self._input_ports = []
        self._output_ports = []
        self._parameters = {}

        # Allocate input ports.
        for i in range(input_count):
            self._input_ports.append(InputPort(self, i))

        # Allocate output ports.
        for i in range(output_count):
            self._output_ports.append(OutputPort(self, i))

        # Connect given input sources, if any.
        if input_sources is not None:
            source_count = len(input_sources)
            if source_count != input_count:
                # "only" was misleading here: too many sources also ends up here.
                raise ValueError(
                    f"Operation expected {input_count} input sources but got {source_count}")
            for i, src in enumerate(input_sources):
                if src is not None:
                    self._input_ports[i].connect(src.source)

    @abstractmethod
    def evaluate(self, *inputs) -> Any:  # pylint: disable=arguments-differ
        """Evaluate the operation.

        Generate a list of output values given a list of input values.
        """
        raise NotImplementedError

    def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Addition, ConstantAddition
        if isinstance(src, Number):
            return ConstantAddition(src, self)
        return Addition(self, src)

    def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Subtraction, ConstantSubtraction
        if isinstance(src, Number):
            return ConstantSubtraction(src, self)
        return Subtraction(self, src)

    def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Multiplication, ConstantMultiplication
        if isinstance(src, Number):
            return ConstantMultiplication(src, self)
        return Multiplication(self, src)

    def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Division, ConstantDivision
        if isinstance(src, Number):
            return ConstantDivision(src, self)
        return Division(self, src)

    @property
    def inputs(self) -> List[InputPort]:
        # Return a copy so callers cannot mutate the internal port list.
        return self._input_ports.copy()

    @property
    def outputs(self) -> List[OutputPort]:
        # Return a copy so callers cannot mutate the internal port list.
        return self._output_ports.copy()

    @property
    def input_count(self) -> int:
        return len(self._input_ports)

    @property
    def output_count(self) -> int:
        return len(self._output_ports)

    def input(self, i: int) -> InputPort:
        return self._input_ports[i]

    def output(self, i: int) -> OutputPort:
        return self._output_ports[i]

    @property
    def params(self) -> Dict[str, Optional[Any]]:
        # Shallow copy: the dict is new, but the values are shared.
        return self._parameters.copy()

    def param(self, name: str) -> Optional[Any]:
        return self._parameters.get(name)

    def set_param(self, name: str, value: Any) -> None:
        self._parameters[name] = value

    def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
        """Evaluate all outputs via evaluate() and validate the result's shape.

        This default implementation ignores i and returns every output.

        Raises:
            RuntimeError: If evaluate() returns the wrong number of outputs or a
                value that is neither a sequence nor a number.
        """
        result = self.evaluate(*input_values)
        # collections.Sequence was removed in Python 3.10; use collections.abc.
        if isinstance(result, collections.abc.Sequence):
            if len(result) != self.output_count:
                raise RuntimeError(
                    "Operation evaluated to incorrect number of outputs")
            return result
        if isinstance(result, Number):
            if self.output_count != 1:
                raise RuntimeError(
                    "Operation evaluated to incorrect number of outputs")
            return [result]
        raise RuntimeError("Operation evaluated to invalid type")

    def split(self) -> Iterable[Operation]:
        """Try to express this operation as a graph of other operations.

        evaluate() is called with one fresh Input operation per input port instead
        of numbers. When evaluate() is built from the overloaded arithmetic
        operators above, the result is a set of new operation objects, and those
        are returned. Otherwise a list containing only self is returned.
        """
        # Import here to avoid circular imports.
        from b_asic.special_operations import Input
        try:
            # evaluate() is variadic, so the inputs must be unpacked. Each port gets
            # its own Input instead of one shared object.
            result = self.evaluate(*[Input() for _ in range(self.input_count)])
        except (TypeError, ValueError):
            # evaluate() cannot operate on operation objects (e.g. numeric-only code).
            return [self]
        if isinstance(result, collections.abc.Sequence) and all(isinstance(e, Operation) for e in result):
            return result
        if isinstance(result, Operation):
            return [result]
        return [self]

    @property
    def neighbors(self) -> Iterable[Operation]:
        """Collect the operations on the other end of every connected signal.

        Sources of input signals come first, then destinations of output signals.
        An operation may appear more than once if several signals lead to it.
        """
        neighbors = []
        for port in self._input_ports:
            for signal in port.signals:
                neighbors.append(signal.source.operation)
        for port in self._output_ports:
            for signal in port.signals:
                neighbors.append(signal.destination.operation)
        return neighbors

    def traverse(self) -> Generator[Operation, None, None]:
        """Yield this operation and all reachable operations in breadth-first order."""
        # Breadth first search.
        visited = {self}
        queue = deque([self])
        while queue:
            operation = queue.popleft()
            yield operation
            for n_operation in operation.neighbors:
                if n_operation not in visited:
                    visited.add(n_operation)
                    queue.append(n_operation)

    @property
    def source(self) -> OutputPort:
        """Return the single output port so the operation can be used as a source.

        Raises:
            TypeError: If the operation does not have exactly one output.
        """
        if self.output_count != 1:
            diff = "more" if self.output_count > 1 else "less"
            raise TypeError(
                f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output")
        return self.output(0)

    def copy_unconnected(self) -> GraphComponent:
        """Return an unconnected copy of this operation with deep-copied parameters."""
        new_comp: AbstractOperation = super().copy_unconnected()
        for name, value in self.params.items():
            new_comp.set_param(name, deepcopy(value))  # pylint: disable=no-member
        return new_comp