"""@package docstring
B-ASIC Operation Module.
TODO: More info.
"""

import collections

from abc import abstractmethod
from numbers import Number
from typing import NewType, List, Sequence, Iterable, Mapping, MutableMapping, Optional, Any, Set, Union
from math import trunc

from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
from b_asic.port import SignalSourceProvider, InputPort, OutputPort
from b_asic.signal import Signal

# Key identifying a stored result or register value (see Operation.key).
ResultKey = NewType("ResultKey", str)
# Read-only mapping from result keys to output values; a value may be None (Optional).
ResultMap = Mapping[ResultKey, Optional[Number]]
# Mutable counterpart of ResultMap; evaluation stores (intermediate) results in it for caching.
MutableResultMap = MutableMapping[ResultKey, Optional[Number]]
# Read-only mapping from result keys to register values.
RegisterMap = Mapping[ResultKey, Number]
# Mutable counterpart of RegisterMap; evaluation may update it with new register values.
MutableRegisterMap = MutableMapping[ResultKey, Number]

class Operation(GraphComponent, SignalSourceProvider):
    """Operation interface.

    An operation is a graph component with input and output ports. Since it is also a
    SignalSourceProvider, an operation can be passed directly as an input source to other
    operations. The arithmetic operators (+, -, *, /) are overloaded to build new operations.
    TODO: More info.
    """

    @abstractmethod
    def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
        """Overloads the addition operator (self + src) to make it return a new Addition
        operation object that is connected to the self and src objects.
        """
        raise NotImplementedError

    @abstractmethod
    def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
        """Overloads the reflected addition operator (src + self) to make it return a new
        Addition operation object that is connected to the src and self objects.
        """
        raise NotImplementedError

    @abstractmethod
    def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
        """Overloads the subtraction operator (self - src) to make it return a new
        Subtraction operation object that is connected to the self and src objects.
        """
        raise NotImplementedError

    @abstractmethod
    def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
        """Overloads the reflected subtraction operator (src - self) to make it return a new
        Subtraction operation object that is connected to the src and self objects.
        """
        raise NotImplementedError

    @abstractmethod
    def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        """Overloads the multiplication operator (self * src) to make it return a new
        Multiplication operation object that is connected to the self and src objects.
        If src is a number, a ConstantMultiplication operation object is returned instead.
        """
        raise NotImplementedError

    @abstractmethod
    def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        """Overloads the reflected multiplication operator (src * self) to make it return a new
        Multiplication operation object that is connected to the src and self objects.
        If src is a number, a ConstantMultiplication operation object is returned instead.
        """
        raise NotImplementedError

    @abstractmethod
    def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
        """Overloads the division operator (self / src) to make it return a new Division
        operation object that is connected to the self and src objects.
        """
        raise NotImplementedError

    @abstractmethod
    def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
        """Overloads the reflected division operator (src / self) to make it return a new
        Division operation object that is connected to the src and self objects.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def input_count(self) -> int:
        """Get the number of input ports."""
        raise NotImplementedError

    @property
    @abstractmethod
    def output_count(self) -> int:
        """Get the number of output ports."""
        raise NotImplementedError

    @abstractmethod
    def input(self, index: int) -> InputPort:
        """Get the input port at the given index."""
        raise NotImplementedError

    @abstractmethod
    def output(self, index: int) -> OutputPort:
        """Get the output port at the given index."""
        raise NotImplementedError

    @property
    @abstractmethod
    def inputs(self) -> Sequence[InputPort]:
        """Get all input ports, ordered by index."""
        raise NotImplementedError

    @property
    @abstractmethod
    def outputs(self) -> Sequence[OutputPort]:
        """Get all output ports, ordered by index."""
        raise NotImplementedError

    @property
    @abstractmethod
    def input_signals(self) -> Iterable[Signal]:
        """Get all the signals that are connected to this operation's input ports,
        in no particular order.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def output_signals(self) -> Iterable[Signal]:
        """Get all the signals that are connected to this operation's output ports,
        in no particular order.
        """
        raise NotImplementedError

    @abstractmethod
    def key(self, index: int, prefix: str = "") -> ResultKey:
        """Get the key used to access the value of a certain output of this operation
        in the results/registers mappings passed to current_output(s) or evaluate_output(s).
        The prefix parameter is used as a prefix for the returned key string.
        """
        raise NotImplementedError

    @abstractmethod
    def current_output(self, index: int, registers: Optional[RegisterMap] = None, prefix: str = "") -> Optional[Number]:
        """Get the current output at the given index of this operation, if available
        (None otherwise).
        The registers parameter will be used for lookup.
        The prefix parameter will be used as a prefix for the key string when looking for registers.
        See also: current_outputs, evaluate_output, evaluate_outputs.
        """
        raise NotImplementedError

    @abstractmethod
    def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number:
        """Evaluate the output at the given index of this operation with the given input values
        and return its value.
        The results parameter will be used to store any results (including intermediate results) for caching.
        The registers parameter will be used to get the current value of any intermediate registers that are encountered, and be updated with their new values.
        The prefix parameter will be used as a prefix for the key string when storing results/registers.
        See also: evaluate_outputs, current_output, current_outputs.
        """
        raise NotImplementedError

    @abstractmethod
    def current_outputs(self, registers: Optional[RegisterMap] = None, prefix: str = "") -> Sequence[Optional[Number]]:
        """Get all current outputs of this operation, if available, ordered by output index.
        See current_output for more information.
        """
        raise NotImplementedError

    @abstractmethod
    def evaluate_outputs(self, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Sequence[Number]:
        """Evaluate all outputs of this operation given the input values, ordered by output index.
        See evaluate_output for more information.
        """
        raise NotImplementedError

    @abstractmethod
    def split(self) -> Iterable["Operation"]:
        """Split the operation into multiple operations.
        If splitting is not possible, this may return a list containing only the operation itself.
        """
        raise NotImplementedError

    @abstractmethod
    def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
        """Get the input indices of all inputs in this operation whose values are required in order to evaluate the output at the given output index."""
        raise NotImplementedError


class AbstractOperation(Operation, AbstractGraphComponent):
    """Generic abstract operation class which most implementations will derive from.

    Provides port allocation, arithmetic operator overloads that build new operations,
    output evaluation (with input truncation and result caching) and a default split()
    implementation. Subclasses must implement evaluate().
    TODO: More info.
    """

    _input_ports: List[InputPort]
    _output_ports: List[OutputPort]

    def __init__(self, input_count: int, output_count: int, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
        """Allocate the ports of the operation and optionally connect its inputs.

        input_count: Number of input ports to allocate.
        output_count: Number of output ports to allocate.
        name: Name of the operation.
        input_sources: Optional sequence with exactly input_count entries. Each non-None
            entry is connected, through its source port, to the input port at the same index.

        Raises ValueError if input_sources does not have exactly input_count entries.
        """
        super().__init__(name)
        self._input_ports = [InputPort(self, i) for i in range(input_count)]
        self._output_ports = [OutputPort(self, i) for i in range(output_count)]

        if input_sources is not None:
            source_count = len(input_sources)
            if source_count != input_count:
                raise ValueError(f"Wrong number of input sources supplied to Operation (expected {input_count}, got {source_count})")
            # Lengths are equal here, so zip pairs every port with its source.
            for port, src in zip(self._input_ports, input_sources):
                if src is not None:
                    port.connect(src.source)

    @abstractmethod
    def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ
        """Compute the output value(s) of the operation from the given input values.

        Must return a single Number for a single-output operation, or a Sequence with
        one value per output port (see evaluate_output, which validates this).
        """
        raise NotImplementedError

    # The operator overloads import from core_operations lazily to avoid circular imports.
    # Plain numbers are wrapped in a Constant operation so they can act as a signal source.

    def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
        """Return Addition(self, src)."""
        from b_asic.core_operations import Constant, Addition
        return Addition(self, Constant(src) if isinstance(src, Number) else src)

    def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
        """Return Addition(src, self)."""
        from b_asic.core_operations import Constant, Addition
        return Addition(Constant(src) if isinstance(src, Number) else src, self)

    def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
        """Return Subtraction(self, src)."""
        from b_asic.core_operations import Constant, Subtraction
        return Subtraction(self, Constant(src) if isinstance(src, Number) else src)

    def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
        """Return Subtraction(src, self)."""
        from b_asic.core_operations import Constant, Subtraction
        return Subtraction(Constant(src) if isinstance(src, Number) else src, self)

    def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        """Return ConstantMultiplication(src, self) for a number, else Multiplication(self, src)."""
        from b_asic.core_operations import Multiplication, ConstantMultiplication
        return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(self, src)

    def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        """Return ConstantMultiplication(src, self) for a number, else Multiplication(src, self)."""
        from b_asic.core_operations import Multiplication, ConstantMultiplication
        return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(src, self)

    def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
        """Return Division(self, src)."""
        from b_asic.core_operations import Constant, Division
        return Division(self, Constant(src) if isinstance(src, Number) else src)

    def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
        """Return Division(src, self)."""
        from b_asic.core_operations import Constant, Division
        return Division(Constant(src) if isinstance(src, Number) else src, self)

    @property
    def input_count(self) -> int:
        """Number of input ports."""
        return len(self._input_ports)

    @property
    def output_count(self) -> int:
        """Number of output ports."""
        return len(self._output_ports)

    def input(self, index: int) -> InputPort:
        """Return the input port at the given index."""
        return self._input_ports[index]

    def output(self, index: int) -> OutputPort:
        """Return the output port at the given index."""
        return self._output_ports[index]

    @property
    def inputs(self) -> Sequence[InputPort]:
        """All input ports, ordered by index."""
        return self._input_ports

    @property
    def outputs(self) -> Sequence[OutputPort]:
        """All output ports, ordered by index."""
        return self._output_ports

    @property
    def input_signals(self) -> Iterable[Signal]:
        """List of all signals connected to the input ports (grouped by port order)."""
        return [signal for port in self.inputs for signal in port.signals]

    @property
    def output_signals(self) -> Iterable[Signal]:
        """List of all signals connected to the output ports (grouped by port order)."""
        return [signal for port in self.outputs for signal in port.signals]

    def key(self, index: int, prefix: str = "") -> ResultKey:
        """Return the result key for the output at the given index.

        For a single-output operation the key is the prefix itself, or str(index) when the
        prefix is empty. Otherwise the index is appended to the prefix, separated by "."
        when the prefix is non-empty.
        """
        key = prefix
        if self.output_count != 1:
            if key:
                key += "."
            key += str(index)
        elif not key:
            key = str(index)
        return key

    def current_output(self, index: int, registers: Optional[RegisterMap] = None, prefix: str = "") -> Optional[Number]:
        """Default implementation: no current output is available, so always return None."""
        return None

    def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number:
        """Evaluate the operation and return the value of the output at the given index.

        The input values are truncated (see truncate_inputs) before being passed to
        evaluate(). The values of ALL outputs are stored in results under their keys
        (see key). The registers parameter is not used by this default implementation.

        Raises IndexError if index is out of range, ValueError if the number of input values
        does not match input_count, and RuntimeError if evaluate() returns the wrong number
        of values or a value that is neither a Sequence nor a Number.
        """
        if index < 0 or index >= self.output_count:
            raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})")
        if len(input_values) != self.input_count:
            raise ValueError(f"Wrong number of input values supplied to operation (expected {self.input_count}, got {len(input_values)})")
        if results is None:
            results = {}

        values = self.evaluate(*self.truncate_inputs(input_values))
        if isinstance(values, collections.abc.Sequence):
            if len(values) != self.output_count:
                raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(values)})")
        elif isinstance(values, Number):
            if self.output_count != 1:
                raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got 1)")
            values = (values,)
        else:
            raise RuntimeError(f"Operation evaluated to invalid type (expected Sequence/Number, got {values.__class__.__name__})")

        # Cache every output produced by this evaluation, not only the requested one.
        for i, value in enumerate(values):
            results[self.key(i, prefix)] = value
        return values[index]

    def current_outputs(self, registers: Optional[RegisterMap] = None, prefix: str = "") -> Sequence[Optional[Number]]:
        """Return current_output(i) for every output index i, in order."""
        return [self.current_output(i, registers, prefix) for i in range(self.output_count)]

    def evaluate_outputs(self, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Sequence[Number]:
        """Return evaluate_output(i, ...) for every output index i, in order."""
        return [self.evaluate_output(i, input_values, results, registers, prefix) for i in range(self.output_count)]

    def split(self) -> Iterable[Operation]:
        """Try to split this operation into simpler operations.

        evaluate() is called with one new Input operation per input port. If it returns an
        Operation, or a Sequence consisting only of Operations, that result is returned
        (wrapped in a list for a single Operation). Otherwise, including when evaluate()
        raises TypeError or ValueError, a list containing only this operation is returned.
        """
        # Import here to avoid circular imports.
        from b_asic.special_operations import Input
        try:
            # Unpack so that evaluate(*inputs) receives one argument per input port, and use
            # a separate Input per port so independent inputs are not merged into one source.
            result = self.evaluate(*[Input() for _ in range(self.input_count)])
            # collections.Sequence was removed in Python 3.10; use collections.abc instead.
            if isinstance(result, collections.abc.Sequence) and all(isinstance(e, Operation) for e in result):
                return result
            if isinstance(result, Operation):
                return [result]
        except (TypeError, ValueError):
            pass
        return [self]

    def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
        """Return the indices of the inputs needed to evaluate the given output.

        By default, each output is assumed to depend on all inputs.
        Raises IndexError if output_index is out of range.
        """
        if output_index < 0 or output_index >= self.output_count:
            raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})")
        return list(range(self.input_count))

    @property
    def neighbors(self) -> Iterable[GraphComponent]:
        """All signals connected to the inputs, followed by all signals connected to the outputs."""
        return list(self.input_signals) + list(self.output_signals)

    @property
    def source(self) -> OutputPort:
        """The single output port, so the operation can be used as an input source.

        Raises TypeError if the operation does not have exactly one output.
        """
        if self.output_count != 1:
            diff = "more" if self.output_count > 1 else "less"
            raise TypeError(
                f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output")
        return self.output(0)

    def truncate_input(self, index: int, value: Number, bits: int) -> Number:
        """Truncate the value to be used as input at the given index to a certain bit length.

        Non-int values are first truncated toward zero with math.trunc. The result keeps only
        the lowest `bits` bits (of the two's complement representation for negative values).
        The index parameter is not used by this default implementation.
        """
        n = value
        if not isinstance(n, int):
            n = trunc(value)
        return n & ((2 ** bits) - 1)

    def truncate_inputs(self, input_values: Sequence[Number]) -> Sequence[Number]:
        """Truncate the values to be used as inputs to the bit lengths specified by the respective signals connected to each input.

        Inputs without a connected signal, or whose first signal has bits set to None, are
        passed through unchanged.
        Raises TypeError if a complex value would have to be truncated.
        """
        args = []
        for i, input_port in enumerate(self.inputs):
            value = input_values[i]
            bits = input_port.signals[0].bits if input_port.signal_count >= 1 else None
            if bits is None:
                args.append(value)
                continue
            if isinstance(value, complex):
                raise TypeError(f"Complex value cannot be truncated to {bits} bits as requested by the signal connected to input #{i}")
            args.append(self.truncate_input(i, value, bits))
        return args