"""@package docstring
B-ASIC Signal Flow Graph Module.
TODO: More info.
"""

from typing import NewType, List, Iterable, Sequence, Dict, Optional, DefaultDict, Set
from numbers import Number
from collections import defaultdict, deque

from b_asic.port import SignalSourceProvider, OutputPort
from b_asic.operation import Operation, AbstractOperation
from b_asic.signal import Signal
from b_asic.graph_component import GraphComponent, Name, TypeName
from b_asic.special_operations import Input, Output


# Graph ID of a component inside an SFG. GraphIDGenerator.next_id builds it as
# the component's type name followed by a decimal counter value.
GraphID = NewType("GraphID", str)
# Numeric counter part of a GraphID; GraphIDGenerator keeps one per type name.
GraphIDNumber = NewType("GraphIDNumber", int)


class GraphIDGenerator:
    """Hand out sequential graph IDs, keeping a separate counter per type name.

    An ID is the type name followed by that type's counter value, e.g. with an
    offset of 0 the IDs for type name ``"t"`` are ``"t1"``, ``"t2"``, ...
    Every counter starts at ``id_number_offset`` and is incremented before it
    is used, so the first ID of a type is numbered ``id_number_offset + 1``.
    """

    _next_id_number: DefaultDict[TypeName, GraphIDNumber]

    def __init__(self, id_number_offset: GraphIDNumber = 0):
        def initial_number():
            # Starting counter value for a type name seen for the first time.
            return id_number_offset

        self._next_id_number = defaultdict(initial_number)

    def next_id(self, type_name: TypeName) -> GraphID:
        """Advance the counter of ``type_name`` and return the resulting graph ID."""
        number = self._next_id_number[type_name] + 1
        self._next_id_number[type_name] = number
        return type_name + str(number)


class SFG(AbstractOperation):
    """Signal flow graph.

    An SFG is an operation built from a graph of operations and signals. The
    constructor copies (via copy_unconnected) every component reachable from the
    given input/output signals and Input/Output operations and connects only
    the copies, so the SFG owns its own internal graph.

    SFG inputs are numbered first by ``input_signals`` and then by ``inputs``;
    SFG outputs are numbered first by ``output_signals`` and then by ``outputs``.
    """

    # Every copied component, keyed by its generated graph ID.
    _components_by_id: Dict[GraphID, GraphComponent]
    # Copied components grouped by name (several components may share a name).
    _components_by_name: DefaultDict[Name, List[GraphComponent]]
    # Copied components in the order the constructor's graph search reached
    # them; boundary Input/Output operations may appear more than once.
    _components_in_dfs_order: List[GraphComponent]
    _graph_id_generator: GraphIDGenerator
    # Internal Input/Output operations, in SFG input/output index order.
    _input_operations: List[Input]
    _output_operations: List[Output]
    # Maps each original component to the copy owned by this SFG.
    _added_components_mapping: Dict[GraphComponent, GraphComponent]
    # Maps original boundary signals to the SFG input/output index they feed.
    _original_input_signals_indexes: Dict[Signal, int]
    _original_output_signals_indexes: Dict[Signal, int]
    _id_number_offset: GraphIDNumber

    def __init__(self, input_signals: Optional[Sequence[Signal]] = None,
                 output_signals: Optional[Sequence[Signal]] = None,
                 inputs: Optional[Sequence[Input]] = None,
                 outputs: Optional[Sequence[Output]] = None,
                 id_number_offset: GraphIDNumber = 0, name: Name = "",
                 input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
        """Construct an SFG by copying the graph reachable from the given boundary.

        Keyword arguments:
        input_signals: Signals whose destinations form SFG inputs. Each signal is
            copied and a new Input operation becomes the copy's source.
        output_signals: Signals whose sources form SFG outputs. Each signal is
            copied and a new Output operation becomes the copy's destination.
        inputs: Existing Input operations; each is copied together with the
            signals leaving its output port.
        outputs: Existing Output operations; each is copied together with the
            signals entering its input port.
        id_number_offset: Offset that the generated graph ID numbers start from.
        name: Name of the SFG.
        input_sources: Passed unchanged to AbstractOperation.__init__.

        Raises ValueError if an input signal has no destination, an output
        signal has no source, or the traversed graph contains an input port
        without signals or a signal missing its source/destination.
        """
        # None replaces shared mutable [] defaults; it means "no such boundary".
        input_signals = [] if input_signals is None else input_signals
        output_signals = [] if output_signals is None else output_signals
        inputs = [] if inputs is None else inputs
        outputs = [] if outputs is None else outputs

        super().__init__(
            input_count=len(input_signals) + len(inputs),
            output_count=len(output_signals) + len(outputs),
            name=name,
            input_sources=input_sources)

        self._components_by_id = dict()
        self._components_by_name = defaultdict(list)
        self._components_in_dfs_order = []
        self._graph_id_generator = GraphIDGenerator(id_number_offset)
        self._input_operations = []
        self._output_operations = []
        # Maps original components to new copied components
        self._added_components_mapping = {}
        self._original_input_signals_indexes = {}
        self._original_output_signals_indexes = {}
        self._id_number_offset = id_number_offset

        # Setup input signals.
        for input_index, sig in enumerate(input_signals):
            assert sig not in self._added_components_mapping, "Duplicate input signals sent to SFG constructor."

            new_input_op = self._add_component_copy_unconnected(Input())
            new_sig = self._add_component_copy_unconnected(sig)
            new_sig.set_source(new_input_op.output(0))

            self._input_operations.append(new_input_op)
            self._original_input_signals_indexes[sig] = input_index

        # Setup input operations, starting from indexes after input signals.
        for input_index, input_op in enumerate(inputs, len(input_signals)):
            assert input_op not in self._added_components_mapping, "Duplicate input operations sent to SFG constructor."
            new_input_op = self._add_component_copy_unconnected(input_op)

            for sig in input_op.output(0).signals:
                assert sig not in self._added_components_mapping, "Duplicate input signals connected to input ports sent to SFG constructor."
                new_sig = self._add_component_copy_unconnected(sig)
                new_sig.set_source(new_input_op.output(0))

                self._original_input_signals_indexes[sig] = input_index

            self._input_operations.append(new_input_op)

        # Setup output signals.
        for output_ind, sig in enumerate(output_signals):
            new_out = self._add_component_copy_unconnected(Output())
            # The signal may already be copied if it connects an SFG input
            # directly to this output.
            new_sig = self._get_or_add_signal_copy(sig)
            new_sig.set_destination(new_out.input(0))

            self._output_operations.append(new_out)
            self._original_output_signals_indexes[sig] = output_ind

        # Setup output operations, starting from indexes after output signals.
        for output_ind, output_op in enumerate(outputs, len(output_signals)):
            assert output_op not in self._added_components_mapping, "Duplicate output operations sent to SFG constructor."

            new_out = self._add_component_copy_unconnected(output_op)
            for sig in output_op.input(0).signals:
                new_sig = self._get_or_add_signal_copy(sig)
                new_sig.set_destination(new_out.input(0))

                self._original_output_signals_indexes[sig] = output_ind

            self._output_operations.append(new_out)

        output_operations_set = set(self._output_operations)

        # Search the graph inwards from each input signal.
        for sig, input_index in self._original_input_signals_indexes.items():
            # Check if already added destination.
            new_sig = self._added_components_mapping[sig]
            if new_sig.destination is None:
                if sig.destination is None:
                    raise ValueError(
                        f"Input signal #{input_index} is missing destination in SFG")
                elif sig.destination.operation not in self._added_components_mapping:
                    self._copy_structure_from_operation_dfs(
                        sig.destination.operation)
            else:
                if new_sig.destination.operation in output_operations_set:
                    # Add directly connected input to output to dfs order list
                    self._components_in_dfs_order.extend([
                        new_sig.source.operation, new_sig, new_sig.destination.operation])

        # Search the graph inwards from each output signal.
        for sig, output_index in self._original_output_signals_indexes.items():
            # Check if already added source.
            new_sig = self._added_components_mapping[sig]
            if new_sig.source is None:
                if sig.source is None:
                    raise ValueError(
                        f"Output signal #{output_index} is missing source in SFG")
                if sig.source.operation not in self._added_components_mapping:
                    self._copy_structure_from_operation_dfs(
                        sig.source.operation)

    def __call__(self):
        """Return a deep copy of this SFG."""
        return self.deep_copy()

    @property
    def type_name(self) -> TypeName:
        return "sfg"

    def evaluate(self, *args):
        """Evaluate every output of this SFG for the given input values.

        Each positional argument is assigned, in order, as the value of the
        corresponding internal Input operation. Returns None when the SFG has
        no outputs, the single result when it has exactly one, and otherwise a
        list with one result per output.

        Raises ValueError if the number of arguments differs from input_count.
        """
        if len(args) != self.input_count:
            raise ValueError(
                "Wrong number of inputs supplied to SFG for evaluation")
        for arg, op in zip(args, self._input_operations):
            op.value = arg

        result = []
        for op in self._output_operations:
            result.append(self._evaluate_source(op.input(0).signals[0].source))

        n = len(result)
        return None if n == 0 else result[0] if n == 1 else result

    def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
        """Evaluate output #i of this SFG for the given input values.

        The input values are assigned to the internal Input operations (as in
        evaluate) before evaluating, so a nested SFG sees the values computed
        by its enclosing graph rather than stale ones.

        Returns a list of length output_count where only index i is filled in
        and every other entry is None.

        Raises ValueError if len(input_values) differs from input_count.
        """
        assert i >= 0 and i < self.output_count, "Output index out of range"
        if len(input_values) != self.input_count:
            raise ValueError(
                "Wrong number of inputs supplied to SFG for evaluation")
        for value, input_op in zip(input_values, self._input_operations):
            input_op.value = value

        result = [None] * self.output_count
        result[i] = self._evaluate_source(
            self._output_operations[i].input(0).signals[0].source)
        return result

    def split(self) -> Iterable[Operation]:
        """Return a lazy iterable over the copied operations in this SFG."""
        return filter(lambda comp: isinstance(comp, Operation), self._components_by_id.values())

    @property
    def components(self) -> Iterable[GraphComponent]:
        """Get all components of this graph in the dfs-traversal order."""
        return self._components_in_dfs_order

    def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]:
        """Find a graph object based on the entered Graph ID and return it. If no graph
        object with the entered ID was found then return None.

        Keyword arguments:
        graph_id: Graph ID of the wanted object.
        """
        return self._components_by_id.get(graph_id, None)

    def find_by_name(self, name: Name) -> List[GraphComponent]:
        """Find all graph objects that have the entered name and return them
        in a list. If no graph object with the entered name was found then return an
        empty list.

        Keyword arguments:
        name: Name of the wanted object.
        """
        return self._components_by_name.get(name, [])

    def deep_copy(self) -> "SFG":
        """Return a deep copy of self.

        The copy is a new SFG built from this SFG's internal Input/Output
        operations, with the same id_number_offset and name.
        """
        copy = SFG(inputs=self._input_operations, outputs=self._output_operations,
                   id_number_offset=self._id_number_offset, name=super().name)

        return copy

    def _add_component_copy_unconnected(self, original_comp: GraphComponent) -> GraphComponent:
        """Copy original_comp without connections and register the copy.

        The copy gets a fresh graph ID, is indexed by its name, and is recorded
        as the copy of original_comp. Returns the copy.
        """
        assert original_comp not in self._added_components_mapping, "Tried to add duplicate SFG component"

        new_comp = original_comp.copy_unconnected()

        self._added_components_mapping[original_comp] = new_comp
        self._components_by_id[self._graph_id_generator.next_id(
            new_comp.type_name)] = new_comp
        self._components_by_name[new_comp.name].append(new_comp)

        return new_comp

    def _get_or_add_signal_copy(self, original_signal: Signal) -> Signal:
        """Return the existing copy of original_signal, or add an unconnected one."""
        if original_signal in self._added_components_mapping:
            return self._added_components_mapping[original_signal]
        return self._add_component_copy_unconnected(original_signal)

    def _copy_structure_from_operation_dfs(self, start_op: Operation):
        """Copy the graph connected to start_op into this SFG, depth first.

        Every operation reached from start_op through input or output signals
        is copied (unless already copied), and copied signals are connected to
        the copied operations' ports. Signals registered as SFG input/output
        signals in the constructor are connected to their existing copies and
        the traversal does not continue past them.

        Raises ValueError for an input port without signals, or for a signal
        missing its source (on an input port) or destination (on an output port).
        """
        op_stack = deque([start_op])

        while op_stack:
            original_op = op_stack.pop()
            # Add or get the new copy of the operation.
            if original_op not in self._added_components_mapping:
                new_op = self._add_component_copy_unconnected(original_op)
                self._components_in_dfs_order.append(new_op)
            else:
                new_op = self._added_components_mapping[original_op]

            # Connect input ports to new signals
            for original_input_port in original_op.inputs:
                if original_input_port.signal_count < 1:
                    raise ValueError("Unconnected input port in SFG")

                for original_signal in original_input_port.signals:

                    # Check if the signal is one of the SFG's input signals
                    if original_signal in self._original_input_signals_indexes:

                        # New signal already created during first step of constructor
                        new_signal = self._added_components_mapping[
                            original_signal]
                        new_signal.set_destination(
                            new_op.input(original_input_port.index))

                        self._components_in_dfs_order.extend(
                            [new_signal, new_signal.source.operation])

                    # Check if the signal has not been added before
                    elif original_signal not in self._added_components_mapping:
                        if original_signal.source is None:
                            raise ValueError(
                                "Dangling signal without source in SFG")

                        new_signal = self._add_component_copy_unconnected(
                            original_signal)
                        new_signal.set_destination(
                            new_op.input(original_input_port.index))

                        self._components_in_dfs_order.append(new_signal)

                        original_connected_op = original_signal.source.operation
                        # Check if connected Operation has been added before
                        if original_connected_op in self._added_components_mapping:
                            # Set source to the already added operations port
                            new_signal.set_source(
                                self._added_components_mapping[original_connected_op].output(
                                    original_signal.source.index))
                        else:
                            # Create new operation, set signal source to it
                            new_connected_op = self._add_component_copy_unconnected(
                                original_connected_op)
                            new_signal.set_source(new_connected_op.output(
                                original_signal.source.index))

                            self._components_in_dfs_order.append(
                                new_connected_op)

                            # Add connected operation to queue of operations to visit
                            op_stack.append(original_connected_op)

            # Connect output ports
            for original_output_port in original_op.outputs:

                for original_signal in original_output_port.signals:
                    # Check if the signal is one of the SFG's output signals.
                    if original_signal in self._original_output_signals_indexes:

                        # New signal already created during first step of constructor.
                        new_signal = self._added_components_mapping[
                            original_signal]
                        new_signal.set_source(
                            new_op.output(original_output_port.index))

                        self._components_in_dfs_order.extend(
                            [new_signal, new_signal.destination.operation])

                    # Check if signal has not been added before.
                    elif original_signal not in self._added_components_mapping:
                        # On an output port the source is this port; the end that
                        # can be missing is the destination, which is used below.
                        if original_signal.destination is None:
                            raise ValueError(
                                "Dangling signal without destination in SFG")

                        new_signal = self._add_component_copy_unconnected(
                            original_signal)
                        new_signal.set_source(
                            new_op.output(original_output_port.index))

                        self._components_in_dfs_order.append(new_signal)

                        original_connected_op = original_signal.destination.operation
                        # Check if connected operation has been added.
                        if original_connected_op in self._added_components_mapping:
                            # Set destination to the already connected operations port
                            new_signal.set_destination(
                                self._added_components_mapping[original_connected_op].input(
                                    original_signal.destination.index))

                        else:
                            # Create new operation, set destination to it.
                            new_connected_op = self._add_component_copy_unconnected(
                                original_connected_op)
                            new_signal.set_destination(new_connected_op.input(
                                original_signal.destination.index))

                            self._components_in_dfs_order.append(
                                new_connected_op)

                            # Add connected operation to the queue of operations to visit
                            op_stack.append(original_connected_op)

    def _evaluate_source(self, src: OutputPort) -> Number:
        """Recursively evaluate the value produced at the output port src.

        Each input of src's operation is evaluated from the source of its first
        signal, then the operation's evaluate_output is called with those values
        and its result is returned as-is.

        NOTE(review): shared sub-graphs are re-evaluated once per use and there
        is no visited set, so a cycle through input signals would recurse until
        RecursionError — confirm how feedback/delay elements are meant to work.
        NOTE(review): SFG.evaluate_output returns a sequence, so a nested SFG
        yields a sequence here rather than a single number — confirm against
        AbstractOperation.evaluate_output.
        """
        input_values = []
        for input_port in src.operation.inputs:
            input_src = input_port.signals[0].source
            input_values.append(self._evaluate_source(input_src))
        return src.operation.evaluate_output(src.index, input_values)