"""@package docstring
B-ASIC Abstract Operation Module.
TODO: More info.
"""

from abc import abstractmethod
from typing import List, Set, Dict, Optional, Any, Iterator
from numbers import Number

from b_asic.port import InputPort, OutputPort
from b_asic.signal import Signal
from b_asic.operation import Operation
from b_asic.simulation import SimulationState, OperationState
from b_asic.utilities import breadth_first_search
from b_asic.abstract_graph_component import AbstractGraphComponent

class AbstractOperation(Operation, AbstractGraphComponent):
    """Generic abstract operation class which most implementations will derive from.

    Holds the operation's input ports, output ports and named parameters, and
    implements the bookkeeping parts of the Operation interface on top of them.
    Concrete subclasses must implement :meth:`evaluate`.
    TODO: More info.
    """

    _input_ports: List[InputPort]
    _output_ports: List[OutputPort]
    _parameters: Dict[str, Optional[Any]]

    def __init__(self, **kwds):
        # Keyword arguments are forwarded so the cooperative base-class
        # initializers (Operation, AbstractGraphComponent) receive them.
        super().__init__(**kwds)
        # All three containers start empty here; presumably they are filled
        # in by subclasses — nothing in this class adds ports or parameters.
        self._input_ports = []
        self._output_ports = []
        self._parameters = {}

    @abstractmethod
    def evaluate(self, inputs: list) -> list:
        """Evaluate the operation and generate a list of output values given a list of input values."""
        raise NotImplementedError

    def inputs(self) -> List[InputPort]:
        """Return a shallow copy of the list of input ports."""
        return self._input_ports.copy()

    def outputs(self) -> List[OutputPort]:
        """Return a shallow copy of the list of output ports."""
        return self._output_ports.copy()

    def input_count(self) -> int:
        """Return the number of input ports."""
        return len(self._input_ports)

    def output_count(self) -> int:
        """Return the number of output ports."""
        return len(self._output_ports)

    def input(self, i: int) -> InputPort:
        """Return the input port at index ``i`` (raises IndexError if out of range)."""
        return self._input_ports[i]

    def output(self, i: int) -> OutputPort:
        """Return the output port at index ``i`` (raises IndexError if out of range)."""
        return self._output_ports[i]

    def params(self) -> Dict[str, Optional[Any]]:
        """Return a shallow copy of the parameter name -> value mapping."""
        return self._parameters.copy()

    def param(self, name: str) -> Optional[Any]:
        """Return the value of parameter ``name``.

        Returns None both when the parameter is unset (stored as None) and when
        no parameter of that name exists.
        """
        return self._parameters.get(name)

    def set_param(self, name: str, value: Any) -> None:
        """Set the value of an already-declared parameter.

        Raises:
            AssertionError: if ``name`` is not a parameter of this operation.
        """
        # NOTE(review): assertions are stripped under ``python -O``; kept as an
        # assert so callers relying on AssertionError are unaffected.
        assert name in self._parameters, (
            f"{type(self).__name__} has no parameter named {name!r}."
        )
        self._parameters[name] = value

    def evaluate_outputs(self, state: SimulationState) -> List[Number]:
        """Bring this operation up to the simulation's iteration and return its outputs.

        While this operation's iteration counter in ``state.operation_states``
        lags behind ``state.iteration``, each pass:

        1. recursively evaluates the operation feeding each input port and
           picks out the value of the connected output port,
        2. calls :meth:`evaluate` and stores the result as the cached outputs,
        3. advances this operation's iteration counter, and
        4. triggers evaluation of every operation fed by one of our outputs.

        Returns:
            The cached output values for the current iteration.
        """
        # TODO: Check implementation.
        input_count: int = self.input_count()
        output_count: int = self.output_count()
        # input_count()/output_count() may be overridden by subclasses; check
        # they agree with the ports that actually exist.
        assert input_count == len(self._input_ports), (
            f"{type(self).__name__}.input_count() returned {input_count}, "
            f"but the operation has {len(self._input_ports)} input ports."
        )
        assert output_count == len(self._output_ports), (
            f"{type(self).__name__}.output_count() returned {output_count}, "
            f"but the operation has {len(self._output_ports)} output ports."
        )

        self_state: OperationState = state.operation_states[self]

        while self_state.iteration < state.iteration:
            input_values: List[Number] = [
                self._read_input(index, port, state)
                for index, port in enumerate(self._input_ports)
            ]

            self_state.output_values = self.evaluate(input_values)
            assert len(self_state.output_values) == output_count, (
                f"{type(self).__name__}.evaluate() produced "
                f"{len(self_state.output_values)} values, expected {output_count}."
            )
            self_state.iteration += 1

            # Push the new values downstream. Signal access mirrors
            # ``neighbours``: a signal's destination is a port whose
            # ``.operation`` is the downstream operation.
            for port in self._output_ports:
                for signal in port.signals:
                    signal.destination.operation.evaluate_outputs(state)

        return self_state.output_values

    def _read_input(self, index: int, port: InputPort, state: SimulationState) -> Number:
        """Return the value arriving at input ``port`` (at position ``index``) this iteration.

        Raises:
            AssertionError: if the port is not connected to exactly one signal.
        """
        signals: List[Signal] = list(port.signals)
        assert len(signals) == 1, (
            f"Input port {index} of {type(self).__name__} must be connected to "
            f"exactly one signal, found {len(signals)}."
        )
        source = signals[0].source
        source_operation: Operation = source.operation
        source_values = source_operation.evaluate_outputs(state)
        # The wanted value sits at the position of the source port among the
        # source operation's output ports.
        return source_values[source_operation.outputs().index(source)]

    def split(self) -> List[Operation]:
        """Try to decompose this operation into simpler operations.

        Calls :meth:`evaluate` with the input *ports* in place of input values.
        If every returned element is an Operation, that list is returned;
        otherwise the operation is considered indivisible and ``[self]`` is
        returned.
        """
        # TODO: Check implementation.
        # NOTE(review): relies on evaluate() accepting ports instead of values
        # (presumably via operator overloading on ports) — confirm.
        results = self.evaluate(self._input_ports)
        if all(isinstance(e, Operation) for e in results):
            return results
        return [self]

    @property
    def neighbours(self) -> List[Operation]:
        """Operations directly connected to this one through a signal.

        Source operations of the input signals come first, then destination
        operations of the output signals. An operation is listed once per
        connecting signal, so duplicates are possible.
        """
        neighbours: List[Operation] = []
        for port in self._input_ports:
            neighbours.extend(signal.source.operation for signal in port.signals)

        for port in self._output_ports:
            neighbours.extend(signal.destination.operation for signal in port.signals)

        return neighbours

    def traverse(self) -> Iterator[Operation]:
        """Traverse the operation graph starting at this operation.

        Delegates to ``breadth_first_search``; per the original documentation
        the result is a generator of operations.
        """
        return breadth_first_search(self)

    # TODO: More stuff.