"""@package docstring
B-ASIC Signal Flow Graph Module.
TODO: More info.
"""

from typing import NewType, List, Iterable, Sequence, Dict, Optional, DefaultDict, Set
from numbers import Number
from collections import defaultdict

from b_asic.port import SignalSourceProvider, OutputPort
from b_asic.operation import Operation, AbstractOperation
from b_asic.signal import Signal
from b_asic.graph_component import GraphComponent, Name, TypeName
from b_asic.special_operations import Input, Output


# Integer counter value used when numbering graph IDs of one type.
GraphIDNumber = NewType("GraphIDNumber", int)
# String identifier of a graph component: a type name followed by a number (e.g. "sfg1").
GraphID = NewType("GraphID", str)


class GraphIDGenerator:
    """Hand out graph IDs of the form "<type name><number>".

    Every type name has its own independent counter. A counter starts at the
    offset given to the constructor and is incremented before each use, so the
    first ID produced for a type carries the number ``id_number_offset + 1``.
    """

    _next_id_number: DefaultDict[TypeName, GraphIDNumber]

    def __init__(self, id_number_offset: GraphIDNumber = 0):
        def initial_number():
            # Counter value for a type name that has not been seen yet.
            return id_number_offset

        self._next_id_number = defaultdict(initial_number)

    def next_id(self, type_name: TypeName) -> GraphID:
        """Advance the counter for ``type_name`` and return the resulting graph ID."""
        number = self._next_id_number[type_name] + 1
        self._next_id_number[type_name] = number
        return type_name + str(number)


class SFG(AbstractOperation):
    """Signal flow graph.

    An SFG is itself an operation. Its inputs are, in order, one per entry in
    ``input_signals`` followed by one per entry in ``inputs``. Its outputs are
    one per entry in ``output_signals`` followed by one per entry in ``outputs``.
    On construction, every component reachable from the given signals and
    operations is copied (via ``copy_unconnected``) into the SFG, and the copies
    are wired together to mirror the original connections. The SFG therefore
    holds its own copies, not the components that were passed in.
    """

    _components_by_id: Dict[GraphID, GraphComponent]
    _components_by_name: DefaultDict[Name, List[GraphComponent]]
    _graph_id_generator: GraphIDGenerator
    _input_operations: List[Input]
    _output_operations: List[Output]
    _original_components_added: Set[GraphComponent]
    # Maps every original component that has been copied to its copy, so an
    # operation reachable along several paths is copied exactly once.
    _original_components_to_new: Dict[GraphComponent, GraphComponent]
    _original_input_signals: Dict[Signal, int]
    _original_output_signals: Dict[Signal, int]

    # Defaults are empty tuples (not lists) so no mutable object is shared between calls.
    def __init__(self, input_signals: Sequence[Signal] = (), output_signals: Sequence[Signal] = (),
                 inputs: Sequence[Input] = (), outputs: Sequence[Output] = (), operations: Sequence[Operation] = (),
                 id_number_offset: GraphIDNumber = 0, name: Name = "",
                 input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
        """Build the SFG by copying the graph reachable from the given signals and operations.

        Keyword arguments:
        input_signals: Signals that become SFG inputs; a new Input operation is created for each.
        output_signals: Signals that become SFG outputs; a new Output operation is created for each.
        inputs: Existing Input operations to copy; the signals leaving them become SFG inputs.
        outputs: Existing Output operations to copy; the signals entering them become SFG outputs.
        operations: Additional operations to copy, together with everything reachable from them.
        id_number_offset: Graph ID numbering offset; the first ID of each type gets offset + 1.
        name: Name of the SFG.
        input_sources: Forwarded unchanged to AbstractOperation.__init__.

        Raises ValueError if an input signal has no destination, an output signal has no source,
        a reached operation has an unconnected input port, or a reached signal is dangling.
        """
        super().__init__(
            input_count = len(input_signals) + len(inputs),
            output_count = len(output_signals) + len(outputs),
            name = name,
            input_sources = input_sources)

        self._components_by_id = dict()
        self._components_by_name = defaultdict(list)
        self._graph_id_generator = GraphIDGenerator(id_number_offset)
        self._input_operations = []
        self._output_operations = []
        self._original_components_added = set()
        self._original_components_to_new = {}
        self._original_input_signals = {}
        self._original_output_signals = {}

        # Setup input operations and signals. Indices from input_signals come first.
        for i, s in enumerate(input_signals):
            self._input_operations.append(self._add_component_copy_unconnected(Input()))
            self._original_input_signals[s] = i
        for i, op in enumerate(inputs, len(input_signals)):
            self._input_operations.append(self._add_component_copy_unconnected(op))
            for s in op.output(0).signals:
                self._original_input_signals[s] = i

        # Setup output operations and signals. Indices from output_signals come first.
        for i, s in enumerate(output_signals):
            self._output_operations.append(self._add_component_copy_unconnected(Output()))
            self._original_output_signals[s] = i
        for i, op in enumerate(outputs, len(output_signals)):
            self._output_operations.append(self._add_component_copy_unconnected(op))
            for s in op.input(0).signals:
                self._original_output_signals[s] = i

        # Search the graph inwards from each input signal.
        for s, i in self._original_input_signals.items():
            if s.destination is None:
                raise ValueError(f"Input signal #{i} is missing destination in SFG")
            self._get_or_add_operation_copy(s.destination.operation)

        # Search the graph inwards from each output signal.
        for s, i in self._original_output_signals.items():
            if s.source is None:
                raise ValueError(f"Output signal #{i} is missing source in SFG")
            self._get_or_add_operation_copy(s.source.operation)

        # Search the graph outwards from each operation.
        for op in operations:
            self._get_or_add_operation_copy(op)

    @property
    def type_name(self) -> TypeName:
        return "sfg"

    def evaluate(self, *args):
        """Evaluate all outputs of the SFG for the given input values.

        Returns None if the SFG has no outputs, the single output value if it
        has exactly one, and otherwise a list with one value per output.

        Raises ValueError if the number of arguments differs from input_count.
        """
        self._set_input_values(args)
        result = [self._evaluate_source(output_op.input(0).signals[0].source)
                  for output_op in self._output_operations]
        if not result:
            return None
        if len(result) == 1:
            return result[0]
        return result

    def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
        """Evaluate output i of the SFG for the given input values.

        Returns a list of length output_count, where entry i holds the computed
        value and every other entry is None.

        Raises ValueError if len(input_values) differs from input_count.
        """
        assert 0 <= i < self.output_count, "Output index out of range"
        # Load the supplied values; previously they were ignored and stale
        # values from an earlier evaluate() call were used instead.
        self._set_input_values(input_values)
        result = [None] * self.output_count
        result[i] = self._evaluate_source(self._output_operations[i].input(0).signals[0].source)
        return result

    def split(self) -> Iterable[Operation]:
        """Return (lazily) the operations among this SFG's components."""
        return filter(lambda comp: isinstance(comp, Operation), self._components_by_id.values())

    @property
    def components(self) -> Iterable[GraphComponent]:
        """Get all components of this graph."""
        return self._components_by_id.values()

    def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]:
        """Find a graph object based on the entered Graph ID and return it. If no graph
        object with the entered ID was found then return None.

        Keyword arguments:
        graph_id: Graph ID of the wanted object.
        """
        return self._components_by_id.get(graph_id)

    def find_by_name(self, name: Name) -> List[GraphComponent]:
        """Find all graph objects that have the entered name and return them
        in a list. If no graph object with the entered name was found then return an
        empty list.

        Keyword arguments:
        name: Name of the wanted object.
        """
        # .get() (rather than indexing) avoids inserting empty entries into the defaultdict.
        return self._components_by_name.get(name, [])

    def _set_input_values(self, values: Sequence[Number]) -> None:
        """Store values, in input-index order, into the SFG's Input operations."""
        if len(values) != self.input_count:
            raise ValueError("Wrong number of inputs supplied to SFG for evaluation")
        for value, input_op in zip(values, self._input_operations):
            input_op.value = value

    def _add_component_copy_unconnected(self, original_comp: GraphComponent) -> GraphComponent:
        """Copy original_comp without connections, register the copy and return it."""
        assert original_comp not in self._original_components_added, "Tried to add duplicate SFG component"
        self._original_components_added.add(original_comp)

        new_comp = original_comp.copy_unconnected()
        self._original_components_to_new[original_comp] = new_comp
        self._components_by_id[self._graph_id_generator.next_id(new_comp.type_name)] = new_comp
        self._components_by_name[new_comp.name].append(new_comp)
        return new_comp

    def _get_or_add_operation_copy(self, original_op: Operation) -> Operation:
        """Return the copy of original_op, copying it (and what it reaches) first if needed."""
        new_op = self._original_components_to_new.get(original_op)
        if new_op is None:
            new_op = self._add_operation_copy_recursively(original_op)
        return new_op

    def _add_operation_copy_recursively(self, original_op: Operation) -> Operation:
        """Copy original_op and, recursively, every signal and operation connected to it.

        Each signal is copied once, from whichever end is reached first. An
        operation that has already been copied is reused, not copied again,
        so graphs with shared or reconverging paths are supported.
        """
        # Add a copy of the operation without any connections. It is registered
        # before recursing, so paths that lead back to it reuse this copy.
        new_op = self._add_component_copy_unconnected(original_op)

        # Connect input ports.
        for original_input_port, new_input_port in zip(original_op.inputs, new_op.inputs):
            if original_input_port.signal_count < 1:
                raise ValueError("Unconnected input port in SFG")
            for original_signal in original_input_port.signals:
                if original_signal in self._original_components_added:
                    continue  # Already copied and connected from its source end.
                new_signal = self._add_component_copy_unconnected(original_signal)
                new_signal.set_destination(new_input_port)
                if original_signal in self._original_input_signals:
                    # One of the SFG's input signals: feed it from the matching Input operation.
                    input_index = self._original_input_signals[original_signal]
                    new_signal.set_source(self._input_operations[input_index].output(0))
                else:
                    if original_signal.source is None:
                        raise ValueError("Dangling signal without source in SFG")
                    new_connected_op = self._get_or_add_operation_copy(original_signal.source.operation)
                    new_signal.set_source(new_connected_op.output(original_signal.source.index))

        # Connect output ports.
        for original_output_port, new_output_port in zip(original_op.outputs, new_op.outputs):
            for original_signal in original_output_port.signals:
                if original_signal in self._original_components_added:
                    continue  # Already copied and connected from its destination end.
                new_signal = self._add_component_copy_unconnected(original_signal)
                new_signal.set_source(new_output_port)
                if original_signal in self._original_output_signals:
                    # One of the SFG's output signals: route it into the matching Output operation.
                    output_index = self._original_output_signals[original_signal]
                    new_signal.set_destination(self._output_operations[output_index].input(0))
                else:
                    if original_signal.destination is None:
                        raise ValueError("Dangling signal without destination in SFG")
                    new_connected_op = self._get_or_add_operation_copy(original_signal.destination.operation)
                    new_signal.set_destination(new_connected_op.input(original_signal.destination.index))

        return new_op

    def _evaluate_source(self, src: OutputPort) -> Number:
        """Compute the value produced at output port src by recursively evaluating its inputs.

        Only the first signal of each input port is followed. There is no
        memoization and no visited set: shared sub-graphs are re-evaluated,
        and a feedback loop would recurse without bound.
        """
        input_values = [self._evaluate_source(input_port.signals[0].source)
                        for input_port in src.operation.inputs]
        return src.operation.evaluate_output(src.index, input_values)