"""@package docstring
B-ASIC Signal Flow Graph Module.
TODO: More info.
"""

from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, Set
from numbers import Number
from collections import defaultdict, deque

from b_asic.port import SignalSourceProvider, OutputPort
from b_asic.operation import Operation, AbstractOperation, ResultKey, RegisterMap, MutableResultMap, MutableRegisterMap
from b_asic.signal import Signal
from b_asic.graph_component import GraphID, GraphIDNumber, GraphComponent, Name, TypeName
from b_asic.special_operations import Input, Output


class GraphIDGenerator:
    """Hand out graph IDs made of a type name followed by a running number.

    Every type name keeps its own counter. A counter for a type name that has
    not been seen before starts at the offset given to the constructor and is
    incremented before use, so the first ID for a type is the type name
    followed by ``offset + 1``.
    """

    _next_id_number: DefaultDict[TypeName, GraphIDNumber]

    def __init__(self, id_number_offset: GraphIDNumber = 0):
        def _initial_counter():
            # Unseen type names start counting from the configured offset.
            return id_number_offset

        self._next_id_number = defaultdict(_initial_counter)

    def next_id(self, type_name: TypeName) -> GraphID:
        """Advance the counter of type_name and return the corresponding graph ID."""
        counters = self._next_id_number
        advanced = counters[type_name] + 1
        counters[type_name] = advanced
        return type_name + str(advanced)

    @property
    def id_number_offset(self) -> GraphIDNumber:
        """The value every per-type counter starts from."""
        initial_counter = self._next_id_number.default_factory
        return initial_counter() # pylint: disable=not-callable


class SFG(AbstractOperation):
    """Signal flow graph.

    An SFG owns copies of every graph component reachable from the signals and
    Input/Output operations it is constructed from; construction only reads the
    original components. Since SFG derives from AbstractOperation, it can itself
    be used as an operation with one input per SFG input and one output per SFG
    output.
    """

    # Maps each graph ID to the copied component registered under it.
    _components_by_id: Dict[GraphID, GraphComponent]
    # Groups the copied components by their (possibly empty) name.
    _components_by_name: DefaultDict[Name, List[GraphComponent]]
    # Copied components in the order they were discovered during traversal.
    _components_ordered: List[GraphComponent]
    # Copied operations in the order they were discovered during traversal.
    _operations_ordered: List[Operation]
    _graph_id_generator: GraphIDGenerator
    # Copied Input operations; list position equals the SFG input index.
    _input_operations: List[Input]
    # Copied Output operations; list position equals the SFG output index.
    _output_operations: List[Output]
    # Maps each original component to its copy inside this SFG.
    _original_components_to_new: Dict[GraphComponent, GraphComponent]
    # Maps each original boundary signal to the SFG input/output index it feeds.
    _original_input_signals_to_indices: Dict[Signal, int]
    _original_output_signals_to_indices: Dict[Signal, int]

    def __init__(self, input_signals: Optional[Sequence[Signal]] = None, output_signals: Optional[Sequence[Signal]] = None,
                 inputs: Optional[Sequence[Input]] = None, outputs: Optional[Sequence[Output]] = None,
                 id_number_offset: GraphIDNumber = 0, name: Name = "",
                 input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
        """Construct an SFG by copying the graph reachable from the given boundary.

        Keyword arguments:
        input_signals: Signals that become SFG inputs. A new Input operation is
            created as the source of each copied signal.
        output_signals: Signals that become SFG outputs. A new Output operation
            is created as the destination of each copied signal.
        inputs: Input operations that become SFG inputs, numbered after
            input_signals.
        outputs: Output operations that become SFG outputs, numbered after
            output_signals.
        id_number_offset: Starting value of the per-type graph ID counters
            (the first ID of each type uses offset + 1).
        name: Name of the SFG.
        input_sources: Sources for the SFG's own input ports, forwarded to
            AbstractOperation.

        Raises:
        ValueError: If a boundary signal, an input port or a traversed signal
            is left unconnected.
        AssertionError: If the same signal or operation is supplied twice
            (checked with assert).
        """
        input_signal_count = 0 if input_signals is None else len(input_signals)
        input_operation_count = 0 if inputs is None else len(inputs)
        output_signal_count = 0 if output_signals is None else len(output_signals)
        output_operation_count = 0 if outputs is None else len(outputs)
        super().__init__(input_count = input_signal_count + input_operation_count,
                         output_count = output_signal_count + output_operation_count,
                         name = name, input_sources = input_sources)

        self._components_by_id = dict()
        self._components_by_name = defaultdict(list)
        self._components_ordered = []
        self._operations_ordered = []
        self._graph_id_generator = GraphIDGenerator(id_number_offset)
        self._input_operations = []
        self._output_operations = []
        self._original_components_to_new = {}
        self._original_input_signals_to_indices = {}
        self._original_output_signals_to_indices = {}

        # Setup input signals: each gets a fresh Input operation as its source.
        if input_signals is not None:
            for input_index, signal in enumerate(input_signals):
                assert signal not in self._original_components_to_new, "Duplicate input signals supplied to SFG constructor."
                new_input_op = self._add_component_unconnected_copy(Input())
                new_signal = self._add_component_unconnected_copy(signal)
                new_signal.set_source(new_input_op.output(0))
                self._input_operations.append(new_input_op)
                self._original_input_signals_to_indices[signal] = input_index

        # Setup input operations, starting from indices after input signals.
        if inputs is not None:
            for input_index, input_op in enumerate(inputs, input_signal_count):
                assert input_op not in self._original_components_to_new, "Duplicate input operations supplied to SFG constructor."
                new_input_op = self._add_component_unconnected_copy(input_op)
                for signal in input_op.output(0).signals:
                    assert signal not in self._original_components_to_new, "Duplicate input signals connected to input ports supplied to SFG constructor."
                    new_signal = self._add_component_unconnected_copy(signal)
                    new_signal.set_source(new_input_op.output(0))
                    self._original_input_signals_to_indices[signal] = input_index

                self._input_operations.append(new_input_op)

        # Setup output signals: each gets a fresh Output operation as its destination.
        if output_signals is not None:
            for output_index, signal in enumerate(output_signals):
                new_output_op = self._add_component_unconnected_copy(Output())
                if signal in self._original_components_to_new:
                    # Signal was already added when setting up inputs.
                    new_signal = self._original_components_to_new[signal]
                else:
                    # New signal has to be created.
                    new_signal = self._add_component_unconnected_copy(signal)
                new_signal.set_destination(new_output_op.input(0))

                self._output_operations.append(new_output_op)
                self._original_output_signals_to_indices[signal] = output_index

        # Setup output operations, starting from indices after output signals.
        if outputs is not None:
            for output_index, output_op in enumerate(outputs, output_signal_count):
                assert output_op not in self._original_components_to_new, "Duplicate output operations supplied to SFG constructor."
                new_output_op = self._add_component_unconnected_copy(output_op)
                for signal in output_op.input(0).signals:
                    if signal in self._original_components_to_new:
                        # Signal was already added when setting up inputs.
                        new_signal = self._original_components_to_new[signal]
                    else:
                        # New signal has to be created.
                        new_signal = self._add_component_unconnected_copy(signal)

                    new_signal.set_destination(new_output_op.input(0))
                    self._original_output_signals_to_indices[signal] = output_index

                self._output_operations.append(new_output_op)

        output_operations_set = set(self._output_operations)

        # Search the graph inwards from each input signal.
        for signal, input_index in self._original_input_signals_to_indices.items():
            # Check if already added destination.
            new_signal = self._original_components_to_new[signal]
            if new_signal.destination is None:
                if signal.destination is None:
                    raise ValueError(f"Input signal #{input_index} is missing destination in SFG")
                if signal.destination.operation not in self._original_components_to_new:
                    self._add_operation_connected_tree_copy(signal.destination.operation)
            elif new_signal.destination.operation in output_operations_set:
                # Add directly connected input to output to ordered list.
                self._components_ordered.extend([new_signal.source.operation, new_signal, new_signal.destination.operation])
                self._operations_ordered.extend([new_signal.source.operation, new_signal.destination.operation])

        # Search the graph inwards from each output signal.
        for signal, output_index in self._original_output_signals_to_indices.items():
            # Check if already added source.
            new_signal = self._original_components_to_new[signal]
            if new_signal.source is None:
                if signal.source is None:
                    raise ValueError(f"Output signal #{output_index} is missing source in SFG")
                if signal.source.operation not in self._original_components_to_new:
                    self._add_operation_connected_tree_copy(signal.source.operation)

    def __str__(self) -> str:
        """Get a string representation of this SFG.

        One line is produced per operation, in discovery order, holding its
        graph ID, name, value (for operations whose type name is "c"), and the
        graph IDs of the signals connected to its input and output ports.
        """
        # Reverse lookup from component identity to its graph ID, built once
        # instead of scanning every component for every port and signal.
        ids_by_component = {id(value): key for key, value in self._components_by_id.items()}

        def _connected_signal_ids(ports) -> str:
            """Join the graph IDs of all signals attached to the given ports."""
            return ", ".join(ids_by_component[id(signal)]
                             for port in ports
                             for signal in port.signals
                             if id(signal) in ids_by_component)

        lines = []
        for component in self._components_ordered:
            if not isinstance(component, Operation):
                continue

            parts = []
            if id(component) in ids_by_component:
                parts.append("id: " + ids_by_component[id(component)] + ", name: ")

            if component.name is not None:
                parts.append(component.name + ", ")
            else:
                parts.append("-, ")

            # Compare by value: `is` on strings depends on interning.
            if component.type_name == "c":
                parts.append("value: " + str(component.value) + ", ")

            # Joining avoids the old trailing-", " trim, which corrupted the
            # text when a port existed but had no signals attached.
            parts.append("input: [" + _connected_signal_ids(component.inputs) + "], ")
            parts.append("output: [" + _connected_signal_ids(component.outputs) + "]\n")
            lines.append("".join(parts))

        return "".join(lines)

    def __call__(self, *src: Optional[SignalSourceProvider], name: Name = "") -> "SFG":
        """Get a new independent SFG instance that is identical to this SFG except without any of its external connections."""
        return SFG(inputs = self._input_operations, outputs = self._output_operations,
                   id_number_offset = self.id_number_offset, name = name, input_sources = src if src else None)

    @property
    def type_name(self) -> TypeName:
        return "sfg"

    def evaluate(self, *args):
        """Evaluate all outputs of this SFG for the given input values.

        Returns None if the SFG has no outputs, the single value if it has
        exactly one output, and otherwise the result of evaluate_outputs.
        """
        result = self.evaluate_outputs(args, {}, {}, "")
        n = len(result)
        return None if n == 0 else result[0] if n == 1 else result

    def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number:
        """Evaluate output `index` of this SFG.

        Side effect: the `value` of this SFG's Input operations is overwritten
        with the supplied input values. Intermediate results are stored in
        `results` and the output value under this SFG's own key.

        Raises:
        IndexError: If index is not a valid output index.
        ValueError: If the number of input values differs from input_count.
        """
        if index < 0 or index >= self.output_count:
            raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})")
        if len(input_values) != self.input_count:
            raise ValueError(f"Wrong number of inputs supplied to SFG for evaluation (expected {self.input_count}, got {len(input_values)})")
        if results is None:
            results = {}
        if registers is None:
            registers = {}

        # Set the values of our input operations to the given input values.
        for op, arg in zip(self._input_operations, self.truncate_inputs(input_values)):
            op.value = arg

        value = self._evaluate_source(self._output_operations[index].input(0).signals[0].source, results, registers, prefix)
        results[self.key(index, prefix)] = value
        return value

    def split(self) -> Iterable[Operation]:
        """Get the operations of this SFG, in discovery order."""
        return self.operations

    def copy_component(self, *args, **kwargs) -> GraphComponent:
        """Copy this SFG, rebuilding it from its Input/Output operations with the same ID offset and name."""
        return super().copy_component(*args, **kwargs, inputs = self._input_operations, outputs = self._output_operations,
                                      id_number_offset = self.id_number_offset, name = self.name)

    @property
    def id_number_offset(self) -> GraphIDNumber:
        """Get the graph id number offset of the graph id generator for this SFG."""
        return self._graph_id_generator.id_number_offset

    @property
    def components(self) -> Iterable[GraphComponent]:
        """Get all components of this graph in depth-first order."""
        return self._components_ordered

    @property
    def operations(self) -> Iterable[Operation]:
        """Get all operations of this graph in depth-first order."""
        return self._operations_ordered

    def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]:
        """Find the graph component with the specified ID.
        Returns None if the component was not found.

        Keyword arguments:
        graph_id: Graph ID of the desired component(s)
        """
        return self._components_by_id.get(graph_id, None)

    def find_by_name(self, name: Name) -> Sequence[GraphComponent]:
        """Find all graph components with the specified name.
        Returns an empty sequence if no components were found.

        Keyword arguments:
        name: Name of the desired component(s)
        """
        return self._components_by_name.get(name, [])

    def _add_component_unconnected_copy(self, original_component: GraphComponent) -> GraphComponent:
        """Copy a component without connections, assign it a fresh graph ID and register it in the lookup tables."""
        assert original_component not in self._original_components_to_new, "Tried to add duplicate SFG component"
        new_component = original_component.copy_component()
        self._original_components_to_new[original_component] = new_component
        new_id = self._graph_id_generator.next_id(new_component.type_name)
        new_component.graph_id = new_id
        self._components_by_id[new_id] = new_component
        self._components_by_name[new_component.name].append(new_component)
        return new_component

    def _add_operation_connected_tree_copy(self, start_op: Operation) -> None:
        """Copy start_op and every operation reachable from it, wiring up the copied signals.

        Traversal is depth-first (LIFO stack). Signals that are SFG boundary
        signals were already copied by the constructor and are only connected
        here.

        Raises:
        ValueError: If an input port is unconnected, or a signal lacks a
            source or destination.
        """
        op_stack = deque([start_op])
        while op_stack:
            original_op = op_stack.pop()
            # Add or get the new copy of the operation.
            if original_op not in self._original_components_to_new:
                new_op = self._add_component_unconnected_copy(original_op)
                self._components_ordered.append(new_op)
                self._operations_ordered.append(new_op)
            else:
                new_op = self._original_components_to_new[original_op]

            # Connect input ports to new signals.
            for original_input_port in original_op.inputs:
                if original_input_port.signal_count < 1:
                    raise ValueError("Unconnected input port in SFG")

                for original_signal in original_input_port.signals:
                    # Check if the signal is one of the SFG's input signals.
                    if original_signal in self._original_input_signals_to_indices:
                        # New signal already created during first step of constructor.
                        new_signal = self._original_components_to_new[original_signal]
                        new_signal.set_destination(new_op.input(original_input_port.index))
                        self._components_ordered.extend([new_signal, new_signal.source.operation])
                        self._operations_ordered.append(new_signal.source.operation)

                    # Check if the signal has not been added before.
                    elif original_signal not in self._original_components_to_new:
                        if original_signal.source is None:
                            raise ValueError("Dangling signal without source in SFG")

                        new_signal = self._add_component_unconnected_copy(original_signal)
                        new_signal.set_destination(new_op.input(original_input_port.index))
                        self._components_ordered.append(new_signal)

                        original_connected_op = original_signal.source.operation
                        # Check if connected Operation has been added before.
                        if original_connected_op in self._original_components_to_new:
                            # Set source to the already added operations port.
                            new_signal.set_source(self._original_components_to_new[original_connected_op].output(original_signal.source.index))
                        else:
                            # Create new operation, set signal source to it.
                            new_connected_op = self._add_component_unconnected_copy(original_connected_op)
                            new_signal.set_source(new_connected_op.output(original_signal.source.index))
                            self._components_ordered.append(new_connected_op)
                            self._operations_ordered.append(new_connected_op)

                            # Add connected operation to the stack of operations to visit.
                            op_stack.append(original_connected_op)

            # Connect output ports.
            for original_output_port in original_op.outputs:
                for original_signal in original_output_port.signals:
                    # Check if the signal is one of the SFG's output signals.
                    if original_signal in self._original_output_signals_to_indices:
                        # New signal already created during first step of constructor.
                        new_signal = self._original_components_to_new[original_signal]
                        new_signal.set_source(new_op.output(original_output_port.index))
                        self._components_ordered.extend([new_signal, new_signal.destination.operation])
                        self._operations_ordered.append(new_signal.destination.operation)

                    # Check if signal has not been added before.
                    elif original_signal not in self._original_components_to_new:
                        # A signal reached from an output port always has a source;
                        # it is the destination that may be missing here.
                        if original_signal.destination is None:
                            raise ValueError("Dangling signal without destination in SFG")

                        new_signal = self._add_component_unconnected_copy(original_signal)
                        new_signal.set_source(new_op.output(original_output_port.index))
                        self._components_ordered.append(new_signal)

                        original_connected_op = original_signal.destination.operation
                        # Check if connected operation has been added.
                        if original_connected_op in self._original_components_to_new:
                            # Set destination to the already connected operations port.
                            new_signal.set_destination(self._original_components_to_new[original_connected_op].input(original_signal.destination.index))
                        else:
                            # Create new operation, set destination to it.
                            new_connected_op = self._add_component_unconnected_copy(original_connected_op)
                            new_signal.set_destination(new_connected_op.input(original_signal.destination.index))
                            self._components_ordered.append(new_connected_op)
                            self._operations_ordered.append(new_connected_op)

                            # Add connected operation to the stack of operations to visit.
                            op_stack.append(original_connected_op)

    def replace_component(self, component: Operation, _component: Optional[Operation] = None, _id: Optional[GraphID] = None) -> "SFG":
        """Replace one operation of this SFG and return a new SFG containing the replacement.

        The operation to replace is looked up by _id when given (taking
        precedence over _component); otherwise _component is used directly.
        Every signal connected to it is reconnected to the corresponding port
        of `component`, and a fresh copy of the resulting graph is returned.

        NOTE: the reconnection is done in place on this SFG's own components,
        so afterwards this SFG still holds the old, presumably disconnected,
        operation and should not be relied upon.

        Arguments:
        component: The new operation, e.g. a Multiplication. It must have the
            same number of inputs and outputs as the operation it replaces.

        Keyword arguments:
        _component: The specific operation to replace.
        _id: The GraphID of the operation to replace.

        Raises:
        AssertionError: If neither _component nor _id is given, no matching
            operation is found, or the port counts differ.
        """
        assert _component is not None or _id is not None, \
            "Define either operation to replace or GraphID of operation"

        if _id is not None:
            _component = self.find_by_id(_id)

        assert _component is not None and isinstance(_component, Operation), \
            "No operation matching the criteria found"
        assert _component.output_count == component.output_count, \
            "The output count may not differ between the operations"
        assert _component.input_count == component.input_count, \
            "The input count may not differ between the operations"

        # Iterate over snapshots: detaching a signal removes it from the port,
        # which may mutate the very list being iterated and skip signals.
        for index_in, _inp in enumerate(_component.inputs):
            for _signal in list(_inp.signals):
                _signal.remove_destination()
                _signal.set_destination(component.input(index_in))

        for index_out, _out in enumerate(_component.outputs):
            for _signal in list(_out.signals):
                _signal.remove_source()
                _signal.set_source(component.output(index_out))

        return self()

    def _evaluate_source(self, src: OutputPort, results: MutableResultMap, registers: MutableRegisterMap, prefix: str) -> Number:
        """Recursively evaluate the value produced at output port `src`, memoized in `results`.

        The result key is derived from `prefix` and the operation's graph ID.
        Before recursing, the operation's current_output is stored under that
        key. If a cycle revisits the key and finds None there, a RuntimeError
        is raised. Presumably current_output is non-None only for stateful
        operations such as delays; confirm in b_asic.operation.
        """
        src_prefix = prefix
        if src_prefix:
            src_prefix += "."
        src_prefix += src.operation.graph_id

        key = src.operation.key(src.index, src_prefix)
        if key in results:
            value = results[key]
            if value is None:
                raise RuntimeError("Direct feedback loop detected when evaluating operation.")
            return value

        results[key] = src.operation.current_output(src.index, registers, src_prefix)
        input_values = [self._evaluate_source(input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs]
        value = src.operation.evaluate_output(src.index, input_values, results, registers, src_prefix)
        results[key] = value
        return value