Skip to content
Snippets Groups Projects
operation.py 16.4 KiB
Newer Older
  • Learn to ignore specific revisions
  • """@package docstring
    B-ASIC Operation Module.
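    Contains the Operation interface and the AbstractOperation base class that most operation implementations derive from.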
    """
    
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
    import collections
    
    
    from abc import abstractmethod
    from numbers import Number
    
    from typing import NewType, List, Sequence, Iterable, Mapping, MutableMapping, Optional, Any, Set, Union
    from math import trunc
    
    
    from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
    from b_asic.port import SignalSourceProvider, InputPort, OutputPort
    
    # String key identifying a single output value inside a result/register map
    # (see Operation.key for how keys are produced).
    ResultKey = NewType("ResultKey", str)
    # Read-only mapping from result key to a value; the value may be None.
    ResultMap = Mapping[ResultKey, Optional[Number]]
    # Mutable counterpart of ResultMap, passed as the ``results`` argument of evaluate_output(s).
    MutableResultMap = MutableMapping[ResultKey, Optional[Number]]
    # Read-only mapping from result key to a register value, passed to current_output(s).
    RegisterMap = Mapping[ResultKey, Number]
    # Mutable counterpart of RegisterMap, passed as the ``registers`` argument of evaluate_output(s).
    MutableRegisterMap = MutableMapping[ResultKey, Number]
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
    class Operation(GraphComponent, SignalSourceProvider):
        """Operation interface.

        Declares the abstract API of an operation: arithmetic operator overloads
        that build new operations, access to the input/output ports and the
        signals connected to them, result keys, output evaluation and splitting.
        Every member raises NotImplementedError and must be overridden.
        """

        @abstractmethod
        def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
            """Overload ``self + src`` to return a new Addition operation
            object that is connected to the self and other objects.
            """
            raise NotImplementedError

        @abstractmethod
        def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
            """Overload ``src + self`` to return a new Addition operation
            object that is connected to the self and other objects.
            """
            raise NotImplementedError

        @abstractmethod
        def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
            """Overload ``self - src`` to return a new Subtraction operation
            object that is connected to the self and other objects.
            """
            raise NotImplementedError

        @abstractmethod
        def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
            """Overload ``src - self`` to return a new Subtraction operation
            object that is connected to the self and other objects.
            """
            raise NotImplementedError

        @abstractmethod
        def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
            """Overload ``self * src`` to return a new Multiplication operation
            object that is connected to the self and other objects. If other is a
            number then a ConstantMultiplication operation object is returned instead.
            """
            raise NotImplementedError

        @abstractmethod
        def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
            """Overload ``src * self`` to return a new Multiplication operation
            object that is connected to the self and other objects. If other is a
            number then a ConstantMultiplication operation object is returned instead.
            """
            raise NotImplementedError

        @abstractmethod
        def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
            """Overload ``self / src`` to return a new Division operation
            object that is connected to the self and other objects.
            """
            raise NotImplementedError

        @abstractmethod
        def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
            """Overload ``src / self`` to return a new Division operation
            object that is connected to the self and other objects.
            """
            raise NotImplementedError

        @property
        @abstractmethod
        def input_count(self) -> int:
            """Get the number of input ports."""
            raise NotImplementedError

        @property
        @abstractmethod
        def output_count(self) -> int:
            """Get the number of output ports."""
            raise NotImplementedError

        @abstractmethod
        def input(self, index: int) -> InputPort:
            """Get the input port at the given index."""
            raise NotImplementedError

        @abstractmethod
        def output(self, index: int) -> OutputPort:
            """Get the output port at the given index."""
            raise NotImplementedError

        # Declared as a property: AbstractOperation implements it as one and
        # iterates over ``self.inputs`` without calling it.
        @property
        @abstractmethod
        def inputs(self) -> Sequence[InputPort]:
            """Get all input ports."""
            raise NotImplementedError

        # Declared as a property for the same reason as ``inputs``.
        @property
        @abstractmethod
        def outputs(self) -> Sequence[OutputPort]:
            """Get all output ports."""
            raise NotImplementedError

        # "Signal" is a string forward reference because the name is not
        # imported in this module; a bare reference would raise NameError
        # when the class body is executed.
        @property
        @abstractmethod
        def input_signals(self) -> Iterable["Signal"]:
            """Get all the signals that are connected to this operation's input ports,
            in no particular order.
            """
            raise NotImplementedError

        # Declared as a property to mirror ``input_signals`` and the
        # AbstractOperation implementation.
        @property
        @abstractmethod
        def output_signals(self) -> Iterable["Signal"]:
            """Get all the signals that are connected to this operation's output ports,
            in no particular order.
            """
            raise NotImplementedError

        @abstractmethod
        def key(self, index: int, prefix: str = "") -> ResultKey:
            """Get the key used to access the result of a certain output of this operation
            from the results parameter passed to current_output(s) or evaluate_output(s).
            """
            raise NotImplementedError

        @abstractmethod
        def current_output(self, index: int, registers: Optional[RegisterMap] = None, prefix: str = "") -> Optional[Number]:
            """Get the current output at the given index of this operation, if available.
            The registers parameter will be used for lookup.
            The prefix parameter will be used as a prefix for the key string when looking for registers.
            See also: current_outputs, evaluate_output, evaluate_outputs.
            """
            raise NotImplementedError

        @abstractmethod
        def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number:
            """Evaluate the output at the given index of this operation with the given input values.
            The results parameter will be used to store any results (including intermediate results) for caching.
            The registers parameter will be used to get the current value of any intermediate registers that are encountered, and be updated with their new values.
            The prefix parameter will be used as a prefix for the key string when storing results/registers.
            See also: evaluate_outputs, current_output, current_outputs.
            """
            raise NotImplementedError

        @abstractmethod
        def current_outputs(self, registers: Optional[RegisterMap] = None, prefix: str = "") -> Sequence[Optional[Number]]:
            """Get all current outputs of this operation, if available.
            See current_output for more information.
            """
            raise NotImplementedError

        @abstractmethod
        def evaluate_outputs(self, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Sequence[Number]:
            """Evaluate all outputs of this operation given the input values.
            See evaluate_output for more information.
            """
            raise NotImplementedError

        @abstractmethod
        def split(self) -> Iterable["Operation"]:
            """Split the operation into multiple operations.
            If splitting is not possible, this may return a list containing only the operation itself.
            """
            raise NotImplementedError
    
    
    class AbstractOperation(Operation, AbstractGraphComponent):
        """Generic abstract operation class which most implementations will derive from.

        Owns the input and output ports, implements the arithmetic operator
        overloads in terms of the core operations, and provides default
        implementations of key generation, output evaluation, splitting and
        input truncation. Subclasses must implement ``evaluate``.
        """

        _input_ports: List[InputPort]    # One port per input, index == position.
        _output_ports: List[OutputPort]  # One port per output, index == position.

        def __init__(self, input_count: int, output_count: int, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
            """Create the operation's ports and optionally connect its inputs.

            input_count / output_count: number of input/output ports to allocate.
            name: forwarded to the parent constructor.
            input_sources: if given, must contain exactly input_count entries;
                every entry that is not None is connected to the input port at
                the same index through its ``source``.

            Raises ValueError if input_sources has the wrong length.
            """
            super().__init__(name)

            self._input_ports = [InputPort(self, i) for i in range(input_count)]
            self._output_ports = [OutputPort(self, i) for i in range(output_count)]

            # Connect given input sources, if any.
            if input_sources is not None:
                source_count = len(input_sources)
                if source_count != input_count:
                    raise ValueError(f"Wrong number of input sources supplied to Operation (expected {input_count}, got {source_count})")
                for i, src in enumerate(input_sources):
                    if src is not None:
                        self._input_ports[i].connect(src.source)

        @abstractmethod
        def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ
            """Evaluate the operation and generate a list of output values given a list of input values."""
            raise NotImplementedError

        def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
            """Return Addition(self, src); a numeric src is wrapped in a Constant."""
            from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports.
            return Addition(self, Constant(src) if isinstance(src, Number) else src)

        def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
            """Return Addition(src, self); a numeric src is wrapped in a Constant."""
            from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports.
            return Addition(Constant(src) if isinstance(src, Number) else src, self)

        def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
            """Return Subtraction(self, src); a numeric src is wrapped in a Constant."""
            from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports.
            return Subtraction(self, Constant(src) if isinstance(src, Number) else src)

        def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
            """Return Subtraction(src, self); a numeric src is wrapped in a Constant."""
            from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports.
            return Subtraction(Constant(src) if isinstance(src, Number) else src, self)

        def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
            """Return ConstantMultiplication(src, self) for a numeric src, otherwise Multiplication(self, src)."""
            from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports.
            return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(self, src)

        def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
            """Return ConstantMultiplication(src, self) for a numeric src, otherwise Multiplication(src, self)."""
            from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports.
            return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(src, self)

        def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
            """Return Division(self, src); a numeric src is wrapped in a Constant."""
            from b_asic.core_operations import Constant, Division # Import here to avoid circular imports.
            return Division(self, Constant(src) if isinstance(src, Number) else src)

        def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
            """Return Division(src, self); a numeric src is wrapped in a Constant."""
            from b_asic.core_operations import Constant, Division # Import here to avoid circular imports.
            return Division(Constant(src) if isinstance(src, Number) else src, self)

        @property
        def input_count(self) -> int:
            """Number of input ports."""
            return len(self._input_ports)

        @property
        def output_count(self) -> int:
            """Number of output ports."""
            return len(self._output_ports)

        def input(self, index: int) -> InputPort:
            """Return the input port at the given index."""
            return self._input_ports[index]

        def output(self, index: int) -> OutputPort:
            """Return the output port at the given index."""
            return self._output_ports[index]

        @property
        def inputs(self) -> Sequence[InputPort]:
            """All input ports (the internal list itself, not a copy)."""
            return self._input_ports

        @property
        def outputs(self) -> Sequence[OutputPort]:
            """All output ports (the internal list itself, not a copy)."""
            return self._output_ports

        # "Signal" is a string forward reference because the name is not
        # imported in this module; a bare reference would raise NameError.
        @property
        def input_signals(self) -> Iterable["Signal"]:
            """New list with the signals of every input port, in port order."""
            return [signal for port in self.inputs for signal in port.signals]

        @property
        def output_signals(self) -> Iterable["Signal"]:
            """New list with the signals of every output port, in port order."""
            return [signal for port in self.outputs for signal in port.signals]

        def key(self, index: int, prefix: str = "") -> ResultKey:
            """Build the result key for the output at the given index.

            With exactly one output the key is the prefix alone, or str(index)
            when the prefix is empty. Otherwise the key is "<prefix>.<index>",
            or just "<index>" when the prefix is empty.
            """
            key = prefix
            if self.output_count != 1:
                if key:
                    key += "."
                key += str(index)
            elif not key:
                key = str(index)
            return key

        def current_output(self, index: int, registers: Optional[RegisterMap] = None, prefix: str = "") -> Optional[Number]:
            """Return None: this base implementation has no current output available."""
            return None

        def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number:
            """Evaluate the operation and return the value of the output at index.

            The input values are truncated with truncate_inputs and passed to
            evaluate. Every output value is stored in results (if given) under
            key(i, prefix). The registers parameter is accepted for interface
            compatibility and is not used by this implementation.

            Raises IndexError if index is out of range, ValueError if the number
            of input values differs from input_count, and RuntimeError if
            evaluate returns the wrong number of values or an unsupported type.
            """
            if index < 0 or index >= self.output_count:
                raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})")
            if len(input_values) != self.input_count:
                raise ValueError(f"Wrong number of input values supplied to operation (expected {self.input_count}, got {len(input_values)})")
            if results is None:
                results = {}

            values = self.evaluate(*self.truncate_inputs(input_values))
            if isinstance(values, collections.abc.Sequence):
                if len(values) != self.output_count:
                    raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(values)})")
            elif isinstance(values, Number):
                if self.output_count != 1:
                    raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got 1)")
                # Normalize a scalar result so it can be indexed like a sequence.
                values = (values,)
            else:
                raise RuntimeError(f"Operation evaluated to invalid type (expected Sequence/Number, got {values.__class__.__name__})")

            # Cache every output; with a single output the only valid index is 0,
            # so this loop also covers that case.
            for i in range(self.output_count):
                results[self.key(i, prefix)] = values[i]
            return values[index]

        def current_outputs(self, registers: Optional[RegisterMap] = None, prefix: str = "") -> Sequence[Optional[Number]]:
            """Return current_output for every output index, in order."""
            return [self.current_output(i, registers, prefix) for i in range(self.output_count)]

        def evaluate_outputs(self, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Sequence[Number]:
            """Return evaluate_output for every output index, in order."""
            return [self.evaluate_output(i, input_values, results, registers, prefix) for i in range(self.output_count)]

        def split(self) -> Iterable[Operation]:
            """Try to split the operation by evaluating it on Input operations.

            evaluate is called with input_count references to one Input
            instance. If it yields an Operation, or a sequence consisting only
            of Operations, that is the result. If evaluate raises TypeError or
            ValueError, or yields anything else, the result is [self].
            """
            # Import here to avoid circular imports.
            from b_asic.special_operations import Input
            try:
                result = self.evaluate(*([Input()] * self.input_count))
                # collections.abc.Sequence: the collections.Sequence alias was
                # removed in Python 3.10 and would raise an uncaught AttributeError.
                if isinstance(result, collections.abc.Sequence) and all(isinstance(e, Operation) for e in result):
                    return result
                if isinstance(result, Operation):
                    return [result]
            except (TypeError, ValueError):
                # Best effort: the operation cannot be evaluated symbolically.
                pass

            return [self]

        @property
        def neighbors(self) -> Iterable[GraphComponent]:
            """Input signals followed by output signals, as a new list."""
            return list(self.input_signals) + list(self.output_signals)

        @property
        def source(self) -> OutputPort:
            """The single output port of this operation.

            Raises TypeError if the operation does not have exactly one output.
            """
            if self.output_count != 1:
                diff = "more" if self.output_count > 1 else "less"
                raise TypeError(
                    f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output")
            return self.output(0)

        def truncate_input(self, index: int, value: Number, bits: int) -> Number:
            """Truncate the value to be used as input at the given index to a certain bit length.

            Non-int values are first converted with math.trunc; the result is
            then masked to its lowest ``bits`` bits. The index is not used by
            this implementation.
            """
            n = value if isinstance(value, int) else trunc(value)
            return n & ((2 ** bits) - 1)

        def truncate_inputs(self, input_values: Sequence[Number]) -> Sequence[Number]:
            """Truncate the values to be used as inputs to the bit lengths specified by the respective signals connected to each input.

            For each input port, the bit length is read from the port's first
            signal. Values for ports without signals, or whose signal has
            ``bits`` set to None, are passed through unchanged.

            Raises TypeError if a complex value would have to be truncated.
            """
            args = []
            for i, input_port in enumerate(self.inputs):
                value = input_values[i]
                bits = input_port.signals[0].bits if input_port.signal_count >= 1 else None
                if bits is None:
                    args.append(value)
                    continue
                if isinstance(value, complex):
                    raise TypeError(f"Complex value cannot be truncated to {bits} bits as requested by the signal connected to input #{i}")
                args.append(self.truncate_input(i, value, bits))
            return args