Skip to content
Snippets Groups Projects
operation.py 11.4 KiB
Newer Older
  • Learn to ignore specific revisions
    """@package docstring
    B-ASIC Operation Module.
    TODO: More info.
    """
    
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
    import collections
    
    
    from abc import abstractmethod
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
    from copy import deepcopy
    
    from numbers import Number
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
    from typing import List, Sequence, Iterable, Dict, Optional, Any, Set, Generator, Union
    
    from collections import deque
    
    from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
    from b_asic.port import SignalSourceProvider, InputPort, OutputPort
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
    class Operation(GraphComponent, SignalSourceProvider):
        """Operation interface.

        An operation is a graph component that owns input ports and output ports,
        holds a dictionary of named parameters, can evaluate its outputs from input
        values, and can itself be used as a signal source for other operations.
        TODO: More info.
        """

        @abstractmethod
        def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]":
            """Overload the addition operator.

            Returns a new Addition operation connected to self and src.
            If src is a number, returns a ConstantAddition operation instead.
            """
            raise NotImplementedError

        @abstractmethod
        def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]":
            """Overload the subtraction operator.

            Returns a new Subtraction operation connected to self and src.
            If src is a number, returns a ConstantSubtraction operation instead.
            """
            raise NotImplementedError

        @abstractmethod
        def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
            """Overload the multiplication operator.

            Returns a new Multiplication operation connected to self and src.
            If src is a number, returns a ConstantMultiplication operation instead.
            """
            raise NotImplementedError

        @abstractmethod
        def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]":
            """Overload the true division operator.

            Returns a new Division operation connected to self and src.
            If src is a number, returns a ConstantDivision operation instead.
            """
            raise NotImplementedError

        @property
        @abstractmethod
        def inputs(self) -> List[InputPort]:
            """Get a list of all input ports."""
            raise NotImplementedError

        @property
        @abstractmethod
        def outputs(self) -> List[OutputPort]:
            """Get a list of all output ports."""
            raise NotImplementedError

        @property
        @abstractmethod
        def input_count(self) -> int:
            """Get the number of input ports."""
            raise NotImplementedError

        @property
        @abstractmethod
        def output_count(self) -> int:
            """Get the number of output ports."""
            raise NotImplementedError

        @abstractmethod
        def input(self, i: int) -> InputPort:
            """Get the input port at index i."""
            raise NotImplementedError

        @abstractmethod
        def output(self, i: int) -> OutputPort:
            """Get the output port at index i."""
            raise NotImplementedError

        # Declared as a property so the interface matches how parameters are
        # accessed (``op.params.items()``), not called.
        @property
        @abstractmethod
        def params(self) -> Dict[str, Optional[Any]]:
            """Get a dictionary of all parameter values."""
            raise NotImplementedError

        @abstractmethod
        def param(self, name: str) -> Optional[Any]:
            """Get the value of a parameter.

            Returns None if the parameter is not defined.
            """
            raise NotImplementedError

        @abstractmethod
        def set_param(self, name: str, value: Any) -> None:
            """Set the value of a parameter.

            Adds the parameter if it is not already defined.
            """
            raise NotImplementedError

        @abstractmethod
        def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
            """Evaluate the output at index i of this operation with the given input values.

            The returned sequence contains results corresponding to each output of this
            operation, where a value of None means it was not evaluated.
            The value at index i is guaranteed to have been evaluated, while the others
            may or may not have been evaluated depending on what is most efficient.
            For example, Butterfly().evaluate_output(1, [5, 4]) may result in either
            (9, 1) or (None, 1).
            """
            raise NotImplementedError

        @abstractmethod
        def split(self) -> Iterable["Operation"]:
            """Split the operation into multiple operations.

            If splitting is not possible, this may return a list containing only the
            operation itself.
            """
            raise NotImplementedError

        @property
        @abstractmethod
        def neighbors(self) -> Iterable["Operation"]:
            """Return all operations that are connected by signals to this operation.

            If no neighbors are found, this returns an empty list.
            """
            raise NotImplementedError

        @abstractmethod
        def traverse(self) -> Generator["Operation", None, None]:
            """Get a generator that recursively iterates through all operations that
            are connected by signals to this operation, as well as the ones that they
            are connected to.
            """
            raise NotImplementedError
    
    
    
    class AbstractOperation(Operation, AbstractGraphComponent):
        """Generic abstract operation class which most implementations will derive from.

        Stores the input ports, output ports and parameters of an operation and
        implements the Operation interface on top of them. Subclasses must
        implement evaluate().
        TODO: More info.
        """

        _input_ports: List[InputPort]
        _output_ports: List[OutputPort]
        _parameters: Dict[str, Optional[Any]]

        def __init__(self, input_count: int, output_count: int, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
            """Create an operation with input_count input ports and output_count output ports.

            If input_sources is given, it must contain exactly input_count entries.
            Every entry that is not None is connected to the input port with the
            same index, via the entry's ``source`` output port.

            Raises:
                ValueError: If input_sources does not have exactly input_count entries.
            """
            super().__init__(name)
            self._input_ports = []
            self._output_ports = []
            self._parameters = {}

            # Allocate input ports.
            for i in range(input_count):
                self._input_ports.append(InputPort(self, i))

            # Allocate output ports.
            for i in range(output_count):
                self._output_ports.append(OutputPort(self, i))

            # Connect given input sources, if any.
            if input_sources is not None:
                source_count = len(input_sources)
                if source_count != input_count:
                    # "but got" rather than "but only got": the count may also be too large.
                    raise ValueError(
                        f"Operation expected {input_count} input sources but got {source_count}")
                for i, src in enumerate(input_sources):
                    if src is not None:
                        self._input_ports[i].connect(src.source)

        @abstractmethod
        def evaluate(self, *inputs) -> Any:  # pylint: disable=arguments-differ
            """Evaluate the operation given one input value per input port.

            evaluate_output() accepts either a single Number (only valid for
            operations with exactly one output) or a sequence with one value per
            output as the return value.
            """
            raise NotImplementedError

        def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]":
            """Return Addition(self, src), or ConstantAddition(src, self) if src is a number."""
            # Import here to avoid circular imports.
            from b_asic.core_operations import Addition, ConstantAddition

            if isinstance(src, Number):
                return ConstantAddition(src, self)
            return Addition(self, src)

        def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]":
            """Return Subtraction(self, src), or ConstantSubtraction(src, self) if src is a number."""
            # Import here to avoid circular imports.
            from b_asic.core_operations import Subtraction, ConstantSubtraction

            if isinstance(src, Number):
                return ConstantSubtraction(src, self)
            return Subtraction(self, src)

        def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
            """Return Multiplication(self, src), or ConstantMultiplication(src, self) if src is a number."""
            # Import here to avoid circular imports.
            from b_asic.core_operations import Multiplication, ConstantMultiplication

            if isinstance(src, Number):
                return ConstantMultiplication(src, self)
            return Multiplication(self, src)

        def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]":
            """Return Division(self, src), or ConstantDivision(src, self) if src is a number."""
            # Import here to avoid circular imports.
            from b_asic.core_operations import Division, ConstantDivision

            if isinstance(src, Number):
                return ConstantDivision(src, self)
            return Division(self, src)

        @property
        def inputs(self) -> List[InputPort]:
            """Return a shallow copy of the input port list."""
            return self._input_ports.copy()

        @property
        def outputs(self) -> List[OutputPort]:
            """Return a shallow copy of the output port list."""
            return self._output_ports.copy()

        @property
        def input_count(self) -> int:
            """Return the number of input ports."""
            return len(self._input_ports)

        @property
        def output_count(self) -> int:
            """Return the number of output ports."""
            return len(self._output_ports)

        def input(self, i: int) -> InputPort:
            """Return the input port at index i (raises IndexError if out of range)."""
            return self._input_ports[i]

        def output(self, i: int) -> OutputPort:
            """Return the output port at index i (raises IndexError if out of range)."""
            return self._output_ports[i]

        @property
        def params(self) -> Dict[str, Optional[Any]]:
            """Return a shallow copy of the parameter dictionary."""
            return self._parameters.copy()

        def param(self, name: str) -> Optional[Any]:
            """Return the value of parameter name, or None if it is not defined."""
            return self._parameters.get(name)

        def set_param(self, name: str, value: Any) -> None:
            """Set (or add) the parameter name to value."""
            self._parameters[name] = value

        def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
            """Evaluate all outputs of this operation for the given input values.

            This implementation always evaluates every output, so i does not
            affect the result.

            Raises:
                RuntimeError: If evaluate() returns the wrong number of outputs or a
                    value that is neither a sequence nor a number.
            """
            # collections.Sequence was removed in Python 3.10; use collections.abc.
            from collections.abc import Sequence as AbcSequence

            result = self.evaluate(*input_values)
            if isinstance(result, AbcSequence):
                if len(result) != self.output_count:
                    raise RuntimeError(
                        "Operation evaluated to incorrect number of outputs")
                return result
            if isinstance(result, Number):
                if self.output_count != 1:
                    raise RuntimeError(
                        "Operation evaluated to incorrect number of outputs")
                return [result]
            raise RuntimeError("Operation evaluated to invalid type")

        def split(self) -> Iterable[Operation]:
            """Try to split the operation by evaluating it symbolically.

            evaluate() is called with one distinct Input operation per input port.
            If that yields an Operation or a sequence of Operations, those are
            returned. Otherwise (including when evaluate() raises TypeError or
            ValueError), a list containing only this operation is returned.
            """
            # Import here to avoid circular imports.
            from b_asic.special_operations import Input
            # collections.Sequence was removed in Python 3.10; use collections.abc.
            from collections.abc import Sequence as AbcSequence

            try:
                # Unpack the inputs to match evaluate(*inputs), and create a separate
                # Input per port so that different ports are not aliased to one
                # shared source.
                result = self.evaluate(*(Input() for _ in range(self.input_count)))
            except (TypeError, ValueError):
                # Best effort: operations that cannot be evaluated on operations
                # are simply not split.
                return [self]

            if isinstance(result, AbcSequence) and all(isinstance(e, Operation) for e in result):
                return result
            if isinstance(result, Operation):
                return [result]
            return [self]

        @property
        def neighbors(self) -> Iterable[Operation]:
            """Return the operations connected to this one by signals.

            Sources of signals entering the input ports come first, then
            destinations of signals leaving the output ports. An operation
            connected by several signals appears once per signal.
            """
            neighbors = []
            for port in self._input_ports:
                for signal in port.signals:
                    neighbors.append(signal.source.operation)
            for port in self._output_ports:
                for signal in port.signals:
                    neighbors.append(signal.destination.operation)
            return neighbors

        def traverse(self) -> Generator[Operation, None, None]:
            """Yield this operation and every operation reachable from it, breadth first.

            Each operation is yielded at most once.
            """
            visited = {self}
            queue = deque([self])
            while queue:
                operation = queue.popleft()
                yield operation
                for n_operation in operation.neighbors:
                    if n_operation not in visited:
                        visited.add(n_operation)
                        queue.append(n_operation)

        @property
        def source(self) -> OutputPort:
            """Return the single output port so this operation can be used as a signal source.

            Raises:
                TypeError: If the operation does not have exactly one output.
            """
            if self.output_count != 1:
                diff = "more" if self.output_count > 1 else "less"
                raise TypeError(
                    f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output")
            return self.output(0)

        def copy_unconnected(self) -> GraphComponent:
            """Return a copy of this operation whose parameters are deep copies of the originals."""
            new_comp: AbstractOperation = super().copy_unconnected()
            for name, value in self.params.items():
                # Deep-copy so the copy does not share mutable parameter values.
                new_comp.set_param(name, deepcopy(value))  # pylint: disable=no-member
            return new_comp