Skip to content
Snippets Groups Projects
operation.py 8.62 KiB
Newer Older
  • Learn to ignore specific revisions
  • """@package docstring
    B-ASIC Operation Module.
    TODO: More info.
    """
    
    from abc import abstractmethod
    from numbers import Number
    
    from typing import List, Dict, Optional, Any, Set, Sequence, TYPE_CHECKING
    
    from collections import deque
    
    from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
    
    if TYPE_CHECKING:
        from b_asic.port import InputPort, OutputPort
    
    
    class Operation(GraphComponent):
        """Operation interface.
        TODO: More info.
        """
    
        @abstractmethod
        def inputs(self) -> "List[InputPort]":
            """Get a list of all input ports."""
            raise NotImplementedError
    
        @abstractmethod
        def outputs(self) -> "List[OutputPort]":
            """Get a list of all output ports."""
            raise NotImplementedError
    
        @abstractmethod
        def input_count(self) -> int:
            """Get the number of input ports."""
            raise NotImplementedError
    
        @abstractmethod
        def output_count(self) -> int:
            """Get the number of output ports."""
            raise NotImplementedError
    
        @abstractmethod
        def input(self, i: int) -> "InputPort":
            """Get the input port at index i."""
            raise NotImplementedError
    
        @abstractmethod
        def output(self, i: int) -> "OutputPort":
            """Get the output port at index i."""
            raise NotImplementedError
    
    
        @abstractmethod
        def evaluate_output(self, i: int, inputs: Sequence[Number]) -> Sequence[Optional[Number]]:
            """Evaluate the output port at the entered index with the entered input values and
            returns all output values that are calulated during the evaluation in a list."""
            raise NotImplementedError
    
    
        @abstractmethod
        def params(self) -> Dict[str, Optional[Any]]:
            """Get a dictionary of all parameter values."""
            raise NotImplementedError
    
        @abstractmethod
        def param(self, name: str) -> Optional[Any]:
            """Get the value of a parameter.
            Returns None if the parameter is not defined.
            """
            raise NotImplementedError
    
        @abstractmethod
        def set_param(self, name: str, value: Any) -> None:
            """Set the value of a parameter.
            The parameter must be defined.
            """
            raise NotImplementedError
    
        @abstractmethod
        def split(self) -> "List[Operation]":
            """Split the operation into multiple operations.
            If splitting is not possible, this may return a list containing only the operation itself.
            """
            raise NotImplementedError
    
        @property
        @abstractmethod
        def neighbors(self) -> "List[Operation]":
            """Return all operations that are connected by signals to this operation.
            If no neighbors are found, this returns an empty list.
            """
            raise NotImplementedError
    
    
    class AbstractOperation(Operation, AbstractGraphComponent):
        """Generic abstract operation class which most implementations will derive from.
        TODO: More info.
        """
    
        _input_ports: List["InputPort"]
        _output_ports: List["OutputPort"]
        _parameters: Dict[str, Optional[Any]]
    
        def __init__(self, name: Name = ""):
            super().__init__(name)
            self._input_ports = []
            self._output_ports = []
            self._parameters = {}
    
        @abstractmethod
        def evaluate(self, *inputs) -> Any:  # pylint: disable=arguments-differ
            """Evaluate the operation and generate a list of output values given a
            list of input values.
            """
            raise NotImplementedError
    
    
        def evaluate_output(self, i: int, inputs: Sequence[Number]) -> Sequence[Optional[Number]]:
            eval_return = self.evaluate(*inputs)
            if isinstance(eval_return, Number):
                return [eval_return]
            elif isinstance(eval_return, (list, tuple)):
                return eval_return
            else:
                raise TypeError("Incorrect returned type from evaluate function.")
    
    
        def inputs(self) -> List["InputPort"]:
            return self._input_ports.copy()
    
        def outputs(self) -> List["OutputPort"]:
            return self._output_ports.copy()
    
        def input_count(self) -> int:
            return len(self._input_ports)
    
        def output_count(self) -> int:
            return len(self._output_ports)
    
        def input(self, i: int) -> "InputPort":
            return self._input_ports[i]
    
        def output(self, i: int) -> "OutputPort":
            return self._output_ports[i]
    
        def params(self) -> Dict[str, Optional[Any]]:
            return self._parameters.copy()
    
        def param(self, name: str) -> Optional[Any]:
            return self._parameters.get(name)
    
        def set_param(self, name: str, value: Any) -> None:
            assert name in self._parameters  # TODO: Error message.
            self._parameters[name] = value
    
        def split(self) -> List[Operation]:
            # TODO: Check implementation.
            results = self.evaluate(self._input_ports)
            if all(isinstance(e, Operation) for e in results):
                return results
            return [self]
    
        @property
        def neighbors(self) -> List[Operation]:
            neighbors: List[Operation] = []
            for port in self._input_ports:
                for signal in port.signals:
                    neighbors.append(signal.source.operation)
    
            for port in self._output_ports:
                for signal in port.signals:
                    neighbors.append(signal.destination.operation)
    
            return neighbors
    
        def traverse(self) -> Operation:
            """Traverse the operation tree and return a generator with start point in the operation."""
            return self._breadth_first_search()
    
        def _breadth_first_search(self) -> Operation:
            """Use breadth first search to traverse the operation tree."""
            visited: Set[Operation] = {self}
            queue = deque([self])
            while queue:
                operation = queue.popleft()
                yield operation
                for n_operation in operation.neighbors:
                    if n_operation not in visited:
                        visited.add(n_operation)
                        queue.append(n_operation)
    
        def __add__(self, other):
            """Overloads the addition operator to make it return a new Addition operation
            object that is connected to the self and other objects. If other is a number then
            returns a ConstantAddition operation object instead.
            """
            # Import here to avoid circular imports.
            from b_asic.core_operations import Addition, ConstantAddition
    
            if isinstance(other, Operation):
                return Addition(self.output(0), other.output(0))
            elif isinstance(other, Number):
                return ConstantAddition(other, self.output(0))
            else:
                raise TypeError("Other type is not an Operation or a Number.")
    
        def __sub__(self, other):
            """Overloads the subtraction operator to make it return a new Subtraction operation
            object that is connected to the self and other objects. If other is a number then
            returns a ConstantSubtraction operation object instead.
            """
            # Import here to avoid circular imports.
            from b_asic.core_operations import Subtraction, ConstantSubtraction
    
            if isinstance(other, Operation):
                return Subtraction(self.output(0), other.output(0))
            elif isinstance(other, Number):
                return ConstantSubtraction(other, self.output(0))
            else:
                raise TypeError("Other type is not an Operation or a Number.")
    
        def __mul__(self, other):
            """Overloads the multiplication operator to make it return a new Multiplication operation
            object that is connected to the self and other objects. If other is a number then
            returns a ConstantMultiplication operation object instead.
            """
            # Import here to avoid circular imports.
            from b_asic.core_operations import Multiplication, ConstantMultiplication
    
            if isinstance(other, Operation):
                return Multiplication(self.output(0), other.output(0))
            elif isinstance(other, Number):
                return ConstantMultiplication(other, self.output(0))
            else:
                raise TypeError("Other type is not an Operation or a Number.")
    
        def __truediv__(self, other):
            """Overloads the division operator to make it return a new Division operation
            object that is connected to the self and other objects. If other is a number then
            returns a ConstantDivision operation object instead.
            """
            # Import here to avoid circular imports.
            from b_asic.core_operations import Division, ConstantDivision
    
            if isinstance(other, Operation):
                return Division(self.output(0), other.output(0))
            elif isinstance(other, Number):
                return ConstantDivision(other, self.output(0))
            else:
                raise TypeError("Other type is not an Operation or a Number.")