"""@package docstring B-ASIC Operation Module. TODO: More info. """ import collections from abc import abstractmethod from numbers import Number from typing import NewType, List, Sequence, Iterable, Mapping, MutableMapping, Optional, Any, Set, Union from math import trunc from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name from b_asic.port import SignalSourceProvider, InputPort, OutputPort from b_asic.signal import Signal ResultKey = NewType("ResultKey", str) ResultMap = Mapping[ResultKey, Optional[Number]] MutableResultMap = MutableMapping[ResultKey, Optional[Number]] RegisterMap = Mapping[ResultKey, Number] MutableRegisterMap = MutableMapping[ResultKey, Number] class Operation(GraphComponent, SignalSourceProvider): """Operation interface. TODO: More info. """ @abstractmethod def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": """Overloads the addition operator to make it return a new Addition operation object that is connected to the self and other objects. """ raise NotImplementedError @abstractmethod def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": """Overloads the addition operator to make it return a new Addition operation object that is connected to the self and other objects. """ raise NotImplementedError @abstractmethod def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": """Overloads the subtraction operator to make it return a new Subtraction operation object that is connected to the self and other objects. """ raise NotImplementedError @abstractmethod def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": """Overloads the subtraction operator to make it return a new Subtraction operation object that is connected to the self and other objects. """ raise NotImplementedError @abstractmethod def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": """Overloads the multiplication operator to make it return a new Multiplication operation object that is connected to the self and other objects. If other is a number then returns a ConstantMultiplication operation object instead. """ raise NotImplementedError @abstractmethod def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": """Overloads the multiplication operator to make it return a new Multiplication operation object that is connected to the self and other objects. If other is a number then returns a ConstantMultiplication operation object instead. """ raise NotImplementedError @abstractmethod def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": """Overloads the division operator to make it return a new Division operation object that is connected to the self and other objects. """ raise NotImplementedError @abstractmethod def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": """Overloads the division operator to make it return a new Division operation object that is connected to the self and other objects. """ raise NotImplementedError @property @abstractmethod def input_count(self) -> int: """Get the number of input ports.""" raise NotImplementedError @property @abstractmethod def output_count(self) -> int: """Get the number of output ports.""" raise NotImplementedError @abstractmethod def input(self, index: int) -> InputPort: """Get the input port at the given index.""" raise NotImplementedError @abstractmethod def output(self, index: int) -> OutputPort: """Get the output port at the given index.""" raise NotImplementedError @property @abstractmethod def inputs(self) -> Sequence[InputPort]: """Get all input ports.""" raise NotImplementedError @property @abstractmethod def outputs(self) -> Sequence[OutputPort]: """Get all output ports.""" raise NotImplementedError @property @abstractmethod def input_signals(self) -> Iterable[Signal]: """Get all the signals that are connected to this operation's input ports, in no particular order. """ raise NotImplementedError @property @abstractmethod def output_signals(self) -> Iterable[Signal]: """Get all the signals that are connected to this operation's output ports, in no particular order. """ raise NotImplementedError @abstractmethod def key(self, index: int, prefix: str = "") -> ResultKey: """Get the key used to access the result of a certain output of this operation from the results parameter passed to current_output(s) or evaluate_output(s). """ raise NotImplementedError @abstractmethod def current_output(self, index: int, registers: Optional[RegisterMap] = None, prefix: str = "") -> Optional[Number]: """Get the current output at the given index of this operation, if available. The registers parameter will be used for lookup. The prefix parameter will be used as a prefix for the key string when looking for registers. See also: current_outputs, evaluate_output, evaluate_outputs. """ raise NotImplementedError @abstractmethod def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: """Evaluate the output at the given index of this operation with the given input values. The results parameter will be used to store any results (including intermediate results) for caching. The registers parameter will be used to get the current value of any intermediate registers that are encountered, and be updated with their new values. The prefix parameter will be used as a prefix for the key string when storing results/registers. See also: evaluate_outputs, current_output, current_outputs. """ raise NotImplementedError @abstractmethod def current_outputs(self, registers: Optional[RegisterMap] = None, prefix: str = "") -> Sequence[Optional[Number]]: """Get all current outputs of this operation, if available. See current_output for more information. """ raise NotImplementedError @abstractmethod def evaluate_outputs(self, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Sequence[Number]: """Evaluate all outputs of this operation given the input values. See evaluate_output for more information. """ raise NotImplementedError @abstractmethod def split(self) -> Iterable["Operation"]: """Split the operation into multiple operations. If splitting is not possible, this may return a list containing only the operation itself. """ raise NotImplementedError class AbstractOperation(Operation, AbstractGraphComponent): """Generic abstract operation class which most implementations will derive from. TODO: More info. """ _input_ports: List[InputPort] _output_ports: List[OutputPort] def __init__(self, input_count: int, output_count: int, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None): super().__init__(name) self._input_ports = [InputPort(self, i) for i in range(input_count)] # Allocate input ports. self._output_ports = [OutputPort(self, i) for i in range(output_count)] # Allocate output ports. # Connect given input sources, if any. if input_sources is not None: source_count = len(input_sources) if source_count != input_count: raise ValueError(f"Wrong number of input sources supplied to Operation (expected {input_count}, got {source_count})") for i, src in enumerate(input_sources): if src is not None: self._input_ports[i].connect(src.source) @abstractmethod def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ """Evaluate the operation and generate a list of output values given a list of input values.""" raise NotImplementedError def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports. return Addition(self, Constant(src) if isinstance(src, Number) else src) def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports. return Addition(Constant(src) if isinstance(src, Number) else src, self) def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports. return Subtraction(self, Constant(src) if isinstance(src, Number) else src) def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports. return Subtraction(Constant(src) if isinstance(src, Number) else src, self) def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports. return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(self, src) def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports. return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(src, self) def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": from b_asic.core_operations import Constant, Division # Import here to avoid circular imports. return Division(self, Constant(src) if isinstance(src, Number) else src) def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": from b_asic.core_operations import Constant, Division # Import here to avoid circular imports. return Division(Constant(src) if isinstance(src, Number) else src, self) @property def input_count(self) -> int: return len(self._input_ports) @property def output_count(self) -> int: return len(self._output_ports) def input(self, index: int) -> InputPort: return self._input_ports[index] def output(self, index: int) -> OutputPort: return self._output_ports[index] @property def inputs(self) -> Sequence[InputPort]: return self._input_ports @property def outputs(self) -> Sequence[OutputPort]: return self._output_ports @property def input_signals(self) -> Iterable[Signal]: result = [] for p in self.inputs: for s in p.signals: result.append(s) return result @property def output_signals(self) -> Iterable[Signal]: result = [] for p in self.outputs: for s in p.signals: result.append(s) return result def key(self, index: int, prefix: str = "") -> ResultKey: key = prefix if self.output_count != 1: if key: key += "." key += str(index) elif not key: key = str(index) return key def current_output(self, index: int, registers: Optional[RegisterMap] = None, prefix: str = "") -> Optional[Number]: return None def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: if index < 0 or index >= self.output_count: raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") if len(input_values) != self.input_count: raise ValueError(f"Wrong number of input values supplied to operation (expected {self.input_count}, got {len(input_values)})") if results is None: results = {} if registers is None: registers = {} values = self.evaluate(*self.truncate_inputs(input_values)) if isinstance(values, collections.abc.Sequence): if len(values) != self.output_count: raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(values)})") elif isinstance(values, Number): if self.output_count != 1: raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got 1)") values = (values,) else: raise RuntimeError(f"Operation evaluated to invalid type (expected Sequence/Number, got {values.__class__.__name__})") if self.output_count == 1: results[self.key(index, prefix)] = values[index] else: for i in range(self.output_count): results[self.key(i, prefix)] = values[i] return values[index] def current_outputs(self, registers: Optional[RegisterMap] = None, prefix: str = "") -> Sequence[Optional[Number]]: return [self.current_output(i, registers, prefix) for i in range(self.output_count)] def evaluate_outputs(self, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Sequence[Number]: return [self.evaluate_output(i, input_values, results, registers, prefix) for i in range(self.output_count)] def split(self) -> Iterable[Operation]: # Import here to avoid circular imports. from b_asic.special_operations import Input try: result = self.evaluate([Input()] * self.input_count) if isinstance(result, collections.Sequence) and all(isinstance(e, Operation) for e in result): return result if isinstance(result, Operation): return [result] except TypeError: pass except ValueError: pass return [self] @property def neighbors(self) -> Iterable[GraphComponent]: return list(self.input_signals) + list(self.output_signals) @property def source(self) -> OutputPort: if self.output_count != 1: diff = "more" if self.output_count > 1 else "less" raise TypeError( f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output") return self.output(0) def truncate_input(self, index: int, value: Number, bits: int) -> Number: """Truncate the value to be used as input at the given index to a certain bit length.""" n = value if not isinstance(n, int): n = trunc(value) return n & ((2 ** bits) - 1) def truncate_inputs(self, input_values: Sequence[Number]) -> Sequence[Number]: """Truncate the values to be used as inputs to the bit lengths specified by the respective signals connected to each input.""" args = [] for i, input_port in enumerate(self.inputs): if input_port.signal_count >= 1: bits = input_port.signals[0].bits if bits is None: args.append(input_values[i]) else: if isinstance(input_values[i], complex): raise TypeError("Complex value cannot be truncated to {bits} bits as requested by the signal connected to input #{i}") args.append(self.truncate_input(i, input_values[i], bits)) else: args.append(input_values[i]) return args