"""@package docstring
B-ASIC Operation Module.
TODO: More info.
"""

import collections
import collections.abc

from abc import abstractmethod
from collections import deque
from copy import deepcopy
from numbers import Number
from typing import List, Sequence, Iterable, Dict, Optional, Any, Set, Generator, Union

from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
from b_asic.port import SignalSourceProvider, InputPort, OutputPort


class Operation(GraphComponent, SignalSourceProvider):
    """Operation interface.

    An operation is a graph component with a number of input ports and output ports,
    a set of named parameters, and the ability to evaluate its outputs from input values.
    TODO: More info.
    """

    @abstractmethod
    def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]":
        """Overloads the addition operator to make it return a new Addition operation
        object that is connected to the self and other objects. If other is a number then
        returns a ConstantAddition operation object instead.
        """
        raise NotImplementedError

    @abstractmethod
    def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]":
        """Overloads the subtraction operator to make it return a new Subtraction operation
        object that is connected to the self and other objects. If other is a number then
        returns a ConstantSubtraction operation object instead.
        """
        raise NotImplementedError

    @abstractmethod
    def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        """Overloads the multiplication operator to make it return a new Multiplication operation
        object that is connected to the self and other objects. If other is a number then
        returns a ConstantMultiplication operation object instead.
        """
        raise NotImplementedError

    @abstractmethod
    def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]":
        """Overloads the division operator to make it return a new Division operation
        object that is connected to the self and other objects. If other is a number then
        returns a ConstantDivision operation object instead.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def inputs(self) -> List[InputPort]:
        """Get a list of all input ports."""
        raise NotImplementedError

    @property
    @abstractmethod
    def outputs(self) -> List[OutputPort]:
        """Get a list of all output ports."""
        raise NotImplementedError

    @property
    @abstractmethod
    def input_count(self) -> int:
        """Get the number of input ports."""
        raise NotImplementedError

    @property
    @abstractmethod
    def output_count(self) -> int:
        """Get the number of output ports."""
        raise NotImplementedError

    @abstractmethod
    def input(self, i: int) -> InputPort:
        """Get the input port at index i."""
        raise NotImplementedError

    @abstractmethod
    def output(self, i: int) -> OutputPort:
        """Get the output port at index i."""
        raise NotImplementedError

    # Declared as a property so the interface matches implementations that expose
    # params as an attribute-style accessor (e.g. `op.params.items()`).
    @property
    @abstractmethod
    def params(self) -> Dict[str, Optional[Any]]:
        """Get a dictionary of all parameter values."""
        raise NotImplementedError

    @abstractmethod
    def param(self, name: str) -> Optional[Any]:
        """Get the value of a parameter.
        Returns None if the parameter is not defined.
        """
        raise NotImplementedError

    @abstractmethod
    def set_param(self, name: str, value: Any) -> None:
        """Set the value of a parameter.
        Adds the parameter if it is not already defined.
        """
        raise NotImplementedError

    @abstractmethod
    def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
        """Evaluate the output at index i of this operation with the given input values.
        The returned sequence contains results corresponding to each output of this operation,
        where a value of None means it was not evaluated.
        The value at index i is guaranteed to have been evaluated, while the others may or may not
        have been evaluated depending on what is the most efficient.
        For example, Butterfly().evaluate_output(1, [5, 4]) may result in either (9, 1) or (None, 1).
        """
        raise NotImplementedError

    @abstractmethod
    def split(self) -> Iterable["Operation"]:
        """Split the operation into multiple operations.
        If splitting is not possible, this may return a list containing only the operation itself.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def neighbors(self) -> Iterable["Operation"]:
        """Return all operations that are connected by signals to this operation.
        If no neighbors are found, this returns an empty list.
        """
        raise NotImplementedError

    @abstractmethod
    def traverse(self) -> Generator["Operation", None, None]:
        """Get a generator that recursively iterates through all operations that are connected by signals to this operation,
        as well as the ones that they are connected to.
        """
        raise NotImplementedError


class AbstractOperation(Operation, AbstractGraphComponent):
    """Generic abstract operation class which most implementations will derive from.

    Stores the input ports, output ports and parameters of an operation and
    implements the Operation interface on top of the abstract evaluate() method.
    TODO: More info.
    """

    _input_ports: List[InputPort]
    _output_ports: List[OutputPort]
    _parameters: Dict[str, Optional[Any]]

    def __init__(self, input_count: int, output_count: int, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
        """Create an operation with input_count input ports and output_count output ports.

        If input_sources is given, it must contain exactly input_count entries; every
        entry that is not None is connected to the input port at the same index.

        Raises ValueError if input_sources does not contain exactly input_count entries.
        """
        super().__init__(name)
        self._input_ports = []
        self._output_ports = []
        self._parameters = {}

        # Allocate input ports.
        for i in range(input_count):
            self._input_ports.append(InputPort(self, i))

        # Allocate output ports.
        for i in range(output_count):
            self._output_ports.append(OutputPort(self, i))

        # Connect given input sources, if any.
        if input_sources is not None:
            source_count = len(input_sources)
            if source_count != input_count:
                # The count may be too small or too large, so don't say "only".
                raise ValueError(
                    f"Operation expected {input_count} input sources but got {source_count}")
            for i, src in enumerate(input_sources):
                if src is not None:
                    self._input_ports[i].connect(src.source)

    @abstractmethod
    def evaluate(self, *inputs) -> Any:  # pylint: disable=arguments-differ
        """Evaluate the operation given one positional argument per input value.

        Should return either a single Number (for single-output operations) or a
        sequence with one value per output; evaluate_output() enforces this.
        """
        raise NotImplementedError

    def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Addition, ConstantAddition

        if isinstance(src, Number):
            return ConstantAddition(src, self)
        return Addition(self, src)

    def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Subtraction, ConstantSubtraction

        if isinstance(src, Number):
            return ConstantSubtraction(src, self)
        return Subtraction(self, src)

    def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Multiplication, ConstantMultiplication

        if isinstance(src, Number):
            return ConstantMultiplication(src, self)
        return Multiplication(self, src)

    def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Division, ConstantDivision

        if isinstance(src, Number):
            return ConstantDivision(src, self)
        return Division(self, src)

    @property
    def inputs(self) -> List[InputPort]:
        # Return a copy so callers cannot mutate the internal port list.
        return self._input_ports.copy()

    @property
    def outputs(self) -> List[OutputPort]:
        # Return a copy so callers cannot mutate the internal port list.
        return self._output_ports.copy()

    @property
    def input_count(self) -> int:
        return len(self._input_ports)

    @property
    def output_count(self) -> int:
        return len(self._output_ports)

    def input(self, i: int) -> InputPort:
        return self._input_ports[i]

    def output(self, i: int) -> OutputPort:
        return self._output_ports[i]

    @property
    def params(self) -> Dict[str, Optional[Any]]:
        # Shallow copy: the dict itself is protected, the values are shared.
        return self._parameters.copy()

    def param(self, name: str) -> Optional[Any]:
        return self._parameters.get(name)

    def set_param(self, name: str, value: Any) -> None:
        self._parameters[name] = value

    def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
        """Evaluate all outputs via evaluate(); i is not used to skip any work here.

        Raises RuntimeError if evaluate() returns the wrong number of outputs or a
        value that is neither a sequence nor a Number.
        """
        result = self.evaluate(*input_values)
        # collections.Sequence was removed in Python 3.10; use the collections.abc ABC.
        if isinstance(result, collections.abc.Sequence):
            if len(result) != self.output_count:
                raise RuntimeError(
                    "Operation evaluated to incorrect number of outputs")
            return result
        if isinstance(result, Number):
            if self.output_count != 1:
                raise RuntimeError(
                    "Operation evaluated to incorrect number of outputs")
            return [result]
        raise RuntimeError("Operation evaluated to invalid type")

    def split(self) -> Iterable[Operation]:
        """Split by evaluating this operation symbolically on fresh Input operations.

        If evaluate() builds new operations from its inputs (via the operator
        overloads), those operations are returned; otherwise [self] is returned.
        """
        # Import here to avoid circular imports.
        from b_asic.special_operations import Input
        try:
            # One distinct Input per input port, passed as separate positional
            # arguments to match the evaluate(*inputs) calling convention.
            result = self.evaluate(*[Input() for _ in range(self.input_count)])
        except (TypeError, ValueError):
            # evaluate() cannot operate on symbolic operation inputs.
            return [self]
        if isinstance(result, collections.abc.Sequence) and all(isinstance(e, Operation) for e in result):
            return result
        if isinstance(result, Operation):
            return [result]
        return [self]

    @property
    def neighbors(self) -> Iterable[Operation]:
        """Operations at the other end of every signal on this operation's ports.

        An operation connected through several signals appears once per signal.
        """
        neighbors = []
        for port in self._input_ports:
            for signal in port.signals:
                neighbors.append(signal.source.operation)
        for port in self._output_ports:
            for signal in port.signals:
                neighbors.append(signal.destination.operation)
        return neighbors

    def traverse(self) -> Generator[Operation, None, None]:
        """Yield this operation, then every reachable operation, in breadth-first order.

        Each operation is yielded at most once.
        """
        visited = {self}
        queue = deque([self])
        while queue:
            operation = queue.popleft()
            yield operation
            for n_operation in operation.neighbors:
                if n_operation not in visited:
                    visited.add(n_operation)
                    queue.append(n_operation)

    @property
    def source(self) -> OutputPort:
        """The single output port, for use as an input source.

        Raises TypeError if the operation does not have exactly one output.
        """
        if self.output_count != 1:
            diff = "more" if self.output_count > 1 else "less"
            raise TypeError(
                f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output")
        return self.output(0)

    def copy_unconnected(self) -> GraphComponent:
        """Copy this operation without its connections, deep-copying every parameter value."""
        new_comp: AbstractOperation = super().copy_unconnected()
        for name, value in self.params.items():
            new_comp.set_param(name, deepcopy(value))  # pylint: disable=no-member
        return new_comp