"""@package docstring
B-ASIC Signal Flow Graph Module.
TODO: More info.
"""

from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, MutableSet
from numbers import Number
from collections import defaultdict, deque

from b_asic.port import SignalSourceProvider, OutputPort, InputPort
from b_asic.operation import Operation, AbstractOperation, MutableOutputMap, MutableRegisterMap
from b_asic.signal import Signal
from b_asic.graph_component import GraphID, GraphIDNumber, GraphComponent, Name, TypeName
from b_asic.special_operations import Input, Output, Register
from b_asic.core_operations import Constant


class GraphIDGenerator:
    """Hands out sequential graph IDs, counted separately for every type name."""

    _next_id_number: DefaultDict[TypeName, GraphIDNumber]

    def __init__(self, id_number_offset: GraphIDNumber = 0):
        """Create a generator whose per-type counters all start at id_number_offset."""
        def initial_number() -> GraphIDNumber:
            return id_number_offset

        # A type name that has not been seen yet starts counting from the offset.
        self._next_id_number = defaultdict(initial_number)

    def next_id(self, type_name: TypeName) -> GraphID:
        """Advance the counter for type_name and return the resulting graph ID.

        The ID is the type name immediately followed by the new counter value,
        so the first ID handed out for a type is numbered offset + 1.
        """
        number = self._next_id_number[type_name] + 1
        self._next_id_number[type_name] = number
        return type_name + str(number)

    @property
    def id_number_offset(self) -> GraphIDNumber:
        """The number that every per-type counter starts from."""
        initial_number = self._next_id_number.default_factory
        return initial_number()  # pylint: disable=not-callable


class SFG(AbstractOperation):
    """Signal flow graph.

    An operation that wraps a graph of other operations and signals. The
    constructor copies the components it is given, together with the
    operations and signals connected to them, and assigns each copy a graph ID.
    """

    # Copied components, keyed by the graph ID assigned by this SFG.
    _components_by_id: Dict[GraphID, GraphComponent]
    # Copied components grouped by their name.
    _components_by_name: DefaultDict[Name, List[GraphComponent]]
    # Copied operations and signals in the order they were added during construction.
    _components_ordered: List[GraphComponent]
    # Copied operations only, in the order they were added during construction.
    _operations_ordered: List[Operation]
    # Source of the graph IDs handed to copied components.
    _graph_id_generator: GraphIDGenerator
    # Internal Input/Output operations, in the same order as the SFG's ports.
    _input_operations: List[Input]
    _output_operations: List[Output]
    # Maps each original component to its copy inside this SFG.
    _original_components_to_new: Dict[GraphComponent, GraphComponent]
    # Map original input/output signals to the index of the port they belong to.
    _original_input_signals_to_indices: Dict[Signal, int]
    _original_output_signals_to_indices: Dict[Signal, int]
    # Cache filled in by get_precedence_list(); None until first computed.
    _precedence_list: Optional[List[List[OutputPort]]]

    def __init__(self, input_signals: Optional[Sequence[Signal]] = None, output_signals: Optional[Sequence[Signal]] = None,
                 inputs: Optional[Sequence[Input]] = None, outputs: Optional[Sequence[Output]] = None,
                 id_number_offset: GraphIDNumber = 0, name: Name = "",
                 input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
        """Construct an SFG from copies of an existing graph of components.

        Arguments:
        input_signals: Signals that become inputs; a new Input operation is
            created as the source of each copied signal.
        output_signals: Signals that become outputs; a new Output operation is
            created as the destination of each copied signal.
        inputs: Input operations to copy, together with the signals connected
            to their output. Their indices follow those of input_signals.
        outputs: Output operations to copy, together with the signals connected
            to their input. Their indices follow those of output_signals.
        id_number_offset: Offset passed to the graph ID generator.
        name: Name forwarded to the base class constructor.
        input_sources: Sources forwarded to the base class constructor.

        Raises ValueError if an input signal has no destination or an output
        signal has no source (further ValueErrors may come from copying the
        connected graph). Duplicate inputs/outputs trip an assertion.
        """

        input_signal_count = 0 if input_signals is None else len(input_signals)
        input_operation_count = 0 if inputs is None else len(inputs)
        output_signal_count = 0 if output_signals is None else len(
            output_signals)
        output_operation_count = 0 if outputs is None else len(outputs)
        # The SFG has one port per supplied signal plus one per supplied operation.
        super().__init__(input_count=input_signal_count + input_operation_count,
                         output_count=output_signal_count + output_operation_count,
                         name=name, input_sources=input_sources)

        self._components_by_id = dict()
        self._components_by_name = defaultdict(list)
        self._components_ordered = []
        self._operations_ordered = []
        self._graph_id_generator = GraphIDGenerator(id_number_offset)
        self._input_operations = []
        self._output_operations = []
        self._original_components_to_new = {}
        self._original_input_signals_to_indices = {}
        self._original_output_signals_to_indices = {}
        self._precedence_list = None

        # Setup input signals.
        if input_signals is not None:
            for input_index, signal in enumerate(input_signals):
                assert signal not in self._original_components_to_new, "Duplicate input signals supplied to SFG construcctor."
                new_input_op = self._add_component_unconnected_copy(Input())
                new_signal = self._add_component_unconnected_copy(signal)
                new_signal.set_source(new_input_op.output(0))
                self._input_operations.append(new_input_op)
                self._original_input_signals_to_indices[signal] = input_index

        # Setup input operations, starting from indices after input signals.
        if inputs is not None:
            for input_index, input_op in enumerate(inputs, input_signal_count):
                assert input_op not in self._original_components_to_new, "Duplicate input operations supplied to SFG constructor."
                new_input_op = self._add_component_unconnected_copy(input_op)
                # Every signal leaving the original input operation is copied and
                # sourced from the copied input operation.
                for signal in input_op.output(0).signals:
                    assert signal not in self._original_components_to_new, "Duplicate input signals connected to input ports supplied to SFG construcctor."
                    new_signal = self._add_component_unconnected_copy(signal)
                    new_signal.set_source(new_input_op.output(0))
                    self._original_input_signals_to_indices[signal] = input_index

                self._input_operations.append(new_input_op)

        # Setup output signals.
        if output_signals is not None:
            for output_index, signal in enumerate(output_signals):
                new_output_op = self._add_component_unconnected_copy(Output())
                if signal in self._original_components_to_new:
                    # Signal was already added when setting up inputs.
                    new_signal = self._original_components_to_new[signal]
                    new_signal.set_destination(new_output_op.input(0))
                else:
                    # New signal has to be created.
                    new_signal = self._add_component_unconnected_copy(signal)
                    new_signal.set_destination(new_output_op.input(0))

                self._output_operations.append(new_output_op)
                self._original_output_signals_to_indices[signal] = output_index

        # Setup output operations, starting from indices after output signals.
        if outputs is not None:
            for output_index, output_op in enumerate(outputs, output_signal_count):
                assert output_op not in self._original_components_to_new, "Duplicate output operations supplied to SFG constructor."
                new_output_op = self._add_component_unconnected_copy(output_op)
                for signal in output_op.input(0).signals:
                    new_signal = None
                    if signal in self._original_components_to_new:
                        # Signal was already added when setting up inputs.
                        new_signal = self._original_components_to_new[signal]
                    else:
                        # New signal has to be created.
                        new_signal = self._add_component_unconnected_copy(
                            signal)

                    new_signal.set_destination(new_output_op.input(0))
                    self._original_output_signals_to_indices[signal] = output_index

                self._output_operations.append(new_output_op)

        output_operations_set = set(self._output_operations)

        # Search the graph inwards from each input signal.
        for signal, input_index in self._original_input_signals_to_indices.items():
            # Check if already added destination.
            new_signal = self._original_components_to_new[signal]
            if new_signal.destination is None:
                if signal.destination is None:
                    raise ValueError(
                        f"Input signal #{input_index} is missing destination in SFG")
                if signal.destination.operation not in self._original_components_to_new:
                    self._add_operation_connected_tree_copy(
                        signal.destination.operation)
            elif new_signal.destination.operation in output_operations_set:
                # Add directly connected input to output to ordered list.
                self._components_ordered.extend(
                    [new_signal.source.operation, new_signal, new_signal.destination.operation])
                self._operations_ordered.extend(
                    [new_signal.source.operation, new_signal.destination.operation])

        # Search the graph inwards from each output signal.
        for signal, output_index in self._original_output_signals_to_indices.items():
            # Check if already added source.
            new_signal = self._original_components_to_new[signal]
            if new_signal.source is None:
                if signal.source is None:
                    raise ValueError(
                        f"Output signal #{output_index} is missing source in SFG")
                if signal.source.operation not in self._original_components_to_new:
                    self._add_operation_connected_tree_copy(
                        signal.source.operation)

    def __str__(self) -> str:
        """Get a string representation of this SFG.

        One line is produced per operation, in the order of the ordered
        component list, on the form
        "id: <graph id>, name: <name>, [value: <value>, ]input: [<signal ids>], output: [<signal ids>]".
        A missing name is printed as "-" and the value field is only present
        for Constant operations.
        """
        # Reverse lookup built once instead of scanning the whole ID table for
        # every component and signal. Keyed on identity, matching the identity
        # comparison that the scan used.
        graph_id_of = {id(component): graph_id
                       for graph_id, component in self._components_by_id.items()}

        def signal_ids(ports) -> str:
            # Joining avoids the trailing-separator trimming that used to eat
            # the opening bracket when a port had no signals attached.
            return ", ".join(graph_id_of[id(signal)]
                             for port in ports
                             for signal in port.signals
                             if id(signal) in graph_id_of)

        lines = []
        for component in self._components_ordered:
            if not isinstance(component, Operation):
                continue

            parts = []
            graph_id = graph_id_of.get(id(component))
            if graph_id is not None:
                parts.append("id: " + graph_id + ", name: ")
            parts.append("-, " if component.name is None else component.name + ", ")
            if isinstance(component, Constant):
                parts.append("value: " + str(component.value) + ", ")
            parts.append("input: [" + signal_ids(component.inputs))
            parts.append("], output: [" + signal_ids(component.outputs))
            parts.append("]\n")
            lines.append("".join(parts))

        return "".join(lines)

    def __call__(self, *src: Optional[SignalSourceProvider], name: Name = "") -> "SFG":
        """Build a new SFG from this SFG's internal input and output operations.

        The new instance reuses this SFG's graph ID number offset, gets the
        given name and is connected to the given sources, if any were passed.
        """
        sources = src or None  # No positional arguments means no input sources.
        return SFG(inputs=self._input_operations,
                   outputs=self._output_operations,
                   id_number_offset=self.id_number_offset,
                   name=name,
                   input_sources=sources)

    @classmethod
    def type_name(cls) -> TypeName:
        """Return the type name shared by all SFG instances."""
        return "sfg"

    def evaluate(self, *args):
        """Evaluate every output of the SFG for the given input values.

        Returns None when there are no results, the single result when there
        is exactly one, and the whole result collection otherwise.
        """
        outputs = self.evaluate_outputs(args, {}, {}, "")
        count = len(outputs)
        if count == 0:
            return None
        if count == 1:
            return outputs[0]
        return outputs

    def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableOutputMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number:
        """Evaluate the output with the given index for the given input values.

        The value is also stored in results under this SFG's key for the
        output. Raises IndexError for an output index out of range and
        ValueError when the number of input values does not match the number
        of inputs.
        """
        if not 0 <= index < self.output_count:
            raise IndexError(
                f"Output index out of range (expected 0-{self.output_count - 1}, got {index})")
        if len(input_values) != self.input_count:
            raise ValueError(
                f"Wrong number of inputs supplied to SFG for evaluation (expected {self.input_count}, got {len(input_values)})")
        results = {} if results is None else results
        registers = {} if registers is None else registers

        # Feed the (truncated) input values into the internal input operations.
        truncated_values = self.truncate_inputs(input_values)
        for input_op, input_value in zip(self._input_operations, truncated_values):
            input_op.value = input_value

        # The output's value is whatever drives the internal output operation.
        output_port = self._output_operations[index].input(0)
        value = self._evaluate_source(
            output_port.signals[0].source, results, registers, prefix)
        results[self.key(index, prefix)] = value
        return value

    def connect_external_signals_to_components(self) -> bool:
        """Connect the signals on this SFG's ports directly to its internal operations.

        The first signal on each input port is re-targeted at the destination
        fed by the matching internal input operation, and the first signal on
        each output port is re-sourced from the source driving the matching
        internal output operation.

        Returns False, without changing anything, when the SFG has no input
        signals or no output signals, and True otherwise. Raises IndexError
        when the number of ports and internal input/output operations differ.
        """
        if len(self.inputs) != len(self.input_operations):
            raise IndexError("Number of inputs does not match the number of input_operations in SFG.")
        if len(self.outputs) != len(self.output_operations):
            raise IndexError("Number of outputs does not match the number of output_operations SFG.")
        if len(self.input_signals) == 0 or len(self.output_signals) == 0:
            return False

        # Route each external input signal past the internal input operation.
        for external_port, input_op in zip(self.inputs, self.input_operations):
            internal_destination = input_op.output(0).signals[0].destination
            internal_destination.clear()
            external_port.signals[0].set_destination(internal_destination)

        # Route each external output signal past the internal output operation.
        for external_port, output_op in zip(self.outputs, self.output_operations):
            internal_source = output_op.input(0).signals[0].source
            internal_source.clear()
            external_port.signals[0].set_source(internal_source)

        return True

    @property
    def input_operations(self) -> Sequence[Operation]:
        """The internal input operations, ordered like the SFG's input ports."""
        return self._input_operations

    @property
    def output_operations(self) -> Sequence[Operation]:
        """The internal output operations, ordered like the SFG's output ports."""
        return self._output_operations

    def split(self) -> Iterable[Operation]:
        """Return the operations that make up this SFG."""
        return self.operations

    def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
        """Return the indices of the SFG inputs that the given output depends on.

        The graph is searched breadth-first backwards from the internal output
        operation, following signal sources; each input operation of this SFG
        that is reached contributes its index, in the order it is reached.
        Raises IndexError when output_index is out of range.
        """
        if not 0 <= output_index < self.output_count:
            raise IndexError(
                f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})")

        unclaimed_indexes = {
            input_op: position for position, input_op in enumerate(self._input_operations)}
        start_op = self._output_operations[output_index]
        pending = deque([start_op])
        seen = {start_op}
        required_indexes = []

        while pending:
            current_op = pending.popleft()
            if isinstance(current_op, Input) and current_op in unclaimed_indexes:
                # Claim the index so that it can only be reported once.
                required_indexes.append(unclaimed_indexes.pop(current_op))

            for port in current_op.inputs:
                for signal in port.signals:
                    source = signal.source
                    if source is None:
                        continue
                    upstream_op = source.operation
                    if upstream_op in seen:
                        continue
                    seen.add(upstream_op)
                    pending.append(upstream_op)

        return required_indexes

    def copy_component(self, *args, **kwargs) -> GraphComponent:
        """Copy this SFG through the base class implementation.

        The internal input/output operations, the graph ID number offset and
        the name of this SFG are passed along as keyword arguments in addition
        to whatever the caller supplied.
        """
        return super().copy_component(
            *args,
            **kwargs,
            inputs=self._input_operations,
            outputs=self._output_operations,
            id_number_offset=self.id_number_offset,
            name=self.name,
        )

    @property
    def id_number_offset(self) -> GraphIDNumber:
        """The graph ID number offset used by this SFG's graph ID generator."""
        generator = self._graph_id_generator
        return generator.id_number_offset

    @property
    def components(self) -> Iterable[GraphComponent]:
        """All components of this graph, in the order they were added during construction."""
        return self._components_ordered

    @property
    def operations(self) -> Iterable[Operation]:
        """All operations of this graph, in the order they were added during construction."""
        return self._operations_ordered

    def get_components_with_type_name(self, type_name: TypeName) -> List[GraphComponent]:
        """Collect the components of this graph that have the given type name.

        Graph IDs are probed in sequence, starting at the number right after
        the ID number offset, until an ID is not found.

        Keyword arguments:
        type_name: The type_name of the desired components.
        """
        matches = []
        number = self.id_number_offset + 1
        while True:
            component = self.find_by_id(type_name + str(number))
            if component is None:
                return matches
            matches.append(component)
            number += 1

    def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]:
        """Look up the component that was assigned the given graph ID.

        Returns None when no component has that ID.

        Keyword arguments:
        graph_id: Graph ID of the desired component.
        """
        return self._components_by_id.get(graph_id)

    def find_by_name(self, name: Name) -> Sequence[GraphComponent]:
        """Look up every component that carries the given name.

        Returns an empty sequence when no component has that name.

        Keyword arguments:
        name: Name of the desired component(s)
        """
        # .get() is used so that a failed lookup does not create an entry.
        matches = self._components_by_name.get(name)
        return [] if matches is None else matches

    def _add_component_unconnected_copy(self, original_component: GraphComponent) -> GraphComponent:
        """Copy a component into this SFG and register the copy.

        The copy is produced by the component's own copy_component(), recorded
        as the counterpart of the original, given the next graph ID for its
        type name and indexed by both ID and name. Adding the same original
        twice trips an assertion.
        """
        assert original_component not in self._original_components_to_new, "Tried to add duplicate SFG component"
        copied = original_component.copy_component()
        self._original_components_to_new[original_component] = copied

        graph_id = self._graph_id_generator.next_id(copied.type_name())
        copied.graph_id = graph_id

        self._components_by_id[graph_id] = copied
        self._components_by_name[copied.name].append(copied)
        return copied

    def _add_operation_connected_tree_copy(self, start_op: Operation) -> None:
        """Copy start_op and everything connected to it into this SFG.

        The original graph is walked depth-first in both directions (through
        input and output ports). Every operation and signal that has not been
        copied yet is copied, connected to the matching ports of the copies and
        appended to the ordered component/operation lists.

        Raises ValueError when an input port has no signal, or when a signal
        that needs to be followed lacks a source or a destination.
        """
        op_stack = deque([start_op])
        while op_stack:
            original_op = op_stack.pop()
            # Add or get the new copy of the operation.
            new_op = None
            if original_op not in self._original_components_to_new:
                new_op = self._add_component_unconnected_copy(original_op)
                self._components_ordered.append(new_op)
                self._operations_ordered.append(new_op)
            else:
                new_op = self._original_components_to_new[original_op]

            # Connect input ports to new signals.
            for original_input_port in original_op.inputs:
                if original_input_port.signal_count < 1:
                    raise ValueError("Unconnected input port in SFG")

                for original_signal in original_input_port.signals:
                    # Check if the signal is one of the SFG's input signals.
                    if original_signal in self._original_input_signals_to_indices:
                        # New signal already created during first step of constructor.
                        new_signal = self._original_components_to_new[original_signal]
                        new_signal.set_destination(
                            new_op.input(original_input_port.index))
                        self._components_ordered.extend(
                            [new_signal, new_signal.source.operation])
                        self._operations_ordered.append(
                            new_signal.source.operation)

                    # Check if the signal has not been added before.
                    elif original_signal not in self._original_components_to_new:
                        if original_signal.source is None:
                            raise ValueError(
                                "Dangling signal without source in SFG")

                        new_signal = self._add_component_unconnected_copy(
                            original_signal)
                        new_signal.set_destination(
                            new_op.input(original_input_port.index))
                        self._components_ordered.append(new_signal)

                        original_connected_op = original_signal.source.operation
                        # Check if connected Operation has been added before.
                        if original_connected_op in self._original_components_to_new:
                            # Set source to the already added operations port.
                            new_signal.set_source(self._original_components_to_new[original_connected_op].output(
                                original_signal.source.index))
                        else:
                            # Create new operation, set signal source to it.
                            new_connected_op = self._add_component_unconnected_copy(
                                original_connected_op)
                            new_signal.set_source(new_connected_op.output(
                                original_signal.source.index))
                            self._components_ordered.append(new_connected_op)
                            self._operations_ordered.append(new_connected_op)

                            # Push the connected operation so that it is visited later.
                            op_stack.append(original_connected_op)

            # Connect output ports.
            for original_output_port in original_op.outputs:
                for original_signal in original_output_port.signals:
                    # Check if the signal is one of the SFG's output signals.
                    if original_signal in self._original_output_signals_to_indices:
                        # New signal already created during first step of constructor.
                        new_signal = self._original_components_to_new[original_signal]
                        new_signal.set_source(
                            new_op.output(original_output_port.index))
                        self._components_ordered.extend(
                            [new_signal, new_signal.destination.operation])
                        self._operations_ordered.append(
                            new_signal.destination.operation)

                    # Check if signal has not been added before.
                    elif original_signal not in self._original_components_to_new:
                        # The signal is followed towards its destination below,
                        # so it is the destination that has to exist here.
                        if original_signal.destination is None:
                            raise ValueError(
                                "Dangling signal without destination in SFG")

                        new_signal = self._add_component_unconnected_copy(
                            original_signal)
                        new_signal.set_source(
                            new_op.output(original_output_port.index))
                        self._components_ordered.append(new_signal)

                        original_connected_op = original_signal.destination.operation
                        # Check if connected operation has been added.
                        if original_connected_op in self._original_components_to_new:
                            # Set destination to the already connected operations port.
                            new_signal.set_destination(self._original_components_to_new[original_connected_op].input(
                                original_signal.destination.index))
                        else:
                            # Create new operation, set destination to it.
                            new_connected_op = self._add_component_unconnected_copy(
                                original_connected_op)
                            new_signal.set_destination(new_connected_op.input(
                                original_signal.destination.index))
                            self._components_ordered.append(new_connected_op)
                            self._operations_ordered.append(new_connected_op)

                            # Push the connected operation so that it is visited later.
                            op_stack.append(original_connected_op)

    def replace_component(self, component: Operation, _id: GraphID):
        """Replace the operation that has the given graph ID with another operation.

        The replacement is made in a copy of this SFG, which is then rebuilt
        and returned; this SFG itself is left untouched. The operation is
        matched on graph ID only. An assertion is tripped when no operation has
        the given ID, or when the input or output counts of the two operations
        differ.

        Arguments:
        component: The new component(s), e.g Multiplication
        _id: The GraphID to match the component to replace.
        """
        sfg_copy = self()
        old_component = sfg_copy.find_by_id(_id)

        assert old_component is not None and isinstance(old_component, Operation), \
            "No operation matching the criteria found"
        assert old_component.output_count == component.output_count, \
            "The output count may not differ between the operations"
        assert old_component.input_count == component.input_count, \
            "The input count may not differ between the operations"

        # Each signal collection is snapshotted with list() before rewiring:
        # detaching a signal may remove it from the port's own collection, and
        # mutating a collection while iterating over it can skip signals.
        for index_in, old_input in enumerate(old_component.inputs):
            for signal in list(old_input.signals):
                signal.remove_destination()
                signal.set_destination(component.input(index_in))

        for index_out, old_output in enumerate(old_component.outputs):
            for signal in list(old_output.signals):
                signal.remove_source()
                signal.set_source(component.output(index_out))

        # Rebuild the SFG so that its bookkeeping reflects the new component.
        return sfg_copy()

    def insert_operation(self, component: Operation, output_comp_id: GraphID):
        """Insert an operation directly after the operation with the given graph ID.

        Each output signal of the source operation is redirected into the
        corresponding input of the new component, and the signal's previous
        destination is connected to the corresponding output of the new
        component. The number of output signals of the source operation must
        equal both the input count and the output count of the new component.

        The insertion is made in a copy of this SFG, which is then rebuilt and
        returned. Returns None when no component has the given graph ID.

        Arguments:
        component: The new component, e.g Multiplication.
        output_comp_id: The source operation GraphID to connect from.
        """
        # Work on a copy so that this SFG is preserved.
        working_sfg = self()
        source_op = working_sfg.find_by_id(output_comp_id)
        if source_op is None:
            return None

        assert not isinstance(source_op, Output), \
            "Source operation can not be an output operation."
        assert len(source_op.output_signals) == component.input_count, \
            "Source operation output count does not match input count for component."
        assert len(source_op.output_signals) == component.output_count, \
            "Destination operation input count does not match output for component."

        for position, signal in enumerate(source_op.output_signals):
            previous_destination = signal.destination
            signal.set_destination(component.input(position))
            previous_destination.connect(component.output(position))

        # Rebuild the SFG so that all of its attributes match the new graph.
        return working_sfg()

    def explode(self) -> None:
        """Destroy the SFG by making it unusable in the future and return all of
        the intermediary operations, the input operations and the output operations.

        Not implemented yet: the method currently does nothing and returns None.
        """
        # TODO: Implement, and restore a return annotation once the result type
        # is decided. The previous annotation had unbalanced brackets and used
        # Tuple without importing it, which made the whole module unparsable.
        return None

    def _evaluate_source(self, src: OutputPort, results: MutableOutputMap, registers: MutableRegisterMap, prefix: str) -> Number:
        """Evaluate the value on an output port, recursing through its operation's inputs.

        Values already present in results are reused. Before the inputs are
        evaluated, the operation's current output is stored under the port's
        key; finding None there on a later visit means that the port depends
        directly on itself, which raises RuntimeError.
        """
        operation = src.operation
        # Nested results are keyed as "<prefix>.<graph id>".
        src_prefix = (prefix + "." if prefix else prefix) + operation.graph_id

        key = operation.key(src.index, src_prefix)
        if key in results:
            cached = results[key]
            if cached is None:
                raise RuntimeError(
                    "Direct feedback loop detected when evaluating operation.")
            return cached

        # Placeholder that is visible to the recursive evaluation below.
        results[key] = operation.current_output(src.index, registers, src_prefix)

        input_values = []
        for input_port in operation.inputs:
            input_values.append(self._evaluate_source(
                input_port.signals[0].source, results, registers, prefix))

        value = operation.evaluate_output(
            src.index, input_values, results, registers, src_prefix)
        results[key] = value
        return value

    def get_precedence_list(self) -> List[List[OutputPort]]:
        """Return the precedence list of the SFG.

        The n:th element of the list holds the output ports that are executed
        in the n:th step. The first step consists of the output ports of all
        operations without inputs and of all registers. The list is computed
        once and the cached version is returned on later calls.
        """
        if self._precedence_list is not None:
            return self._precedence_list

        # The ordered operation list can contain the same input operation more
        # than once (it is appended once per signal it drives). Deduplicate,
        # keeping the order, so that no port is listed or traversed twice;
        # a repeated port would decrement the dependency counters twice.
        unique_ops = list(dict.fromkeys(self.operations))

        # Find all operations with only outputs and no inputs.
        no_input_ops = [op for op in unique_ops if op.input_count == 0]
        reg_ops = self.get_components_with_type_name(Register.type_name())
        start_ops = list(dict.fromkeys(no_input_ops + reg_ops))

        # Find all first iter output ports for precedence
        first_iter_ports = [op.output(i)
                            for op in start_ops
                            for i in range(op.output_count)]

        self._precedence_list = self._traverse_for_precedence_list(first_iter_ports)

        return self._precedence_list

    def _traverse_for_precedence_list(self, first_iter_ports):
        """Build the precedence list by traversing forward from the given output ports.

        An output port is placed in the step after the one in which the last of
        the input ports it depends on has been reached. Signals that lead into
        a Register are not followed.
        """
        # For every input port: the output ports of the same operation that
        # depend on it. For every output port: how many inputs it still awaits.
        dependents_of_inport = defaultdict(list)
        pending_inputs_of_outport = {}
        for operation in self.operations:
            inports = operation.inputs
            for output_index, outport in enumerate(operation.outputs):
                required_indexes = operation.inputs_required_for_output(output_index)
                pending_inputs_of_outport[outport] = len(required_indexes)
                for input_index in required_indexes:
                    dependents_of_inport[inports[input_index]].append(outport)

        steps = []
        current_step = first_iter_ports
        while current_step:
            steps.append(current_step)

            next_step = []
            for outport in current_step:
                for signal in outport.signals:
                    inport = signal.destination
                    # Skip unconnected signals and do not traverse over Registers.
                    if inport is None or isinstance(inport.operation, Register):
                        continue
                    for dependent in dependents_of_inport[inport]:
                        pending_inputs_of_outport[dependent] -= 1
                        if pending_inputs_of_outport[dependent] == 0:
                            next_step.append(dependent)

            current_step = next_step

        return steps