import re
from itertools import combinations
from typing import Dict, List, Optional, Set, Tuple, Union

import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.axes import Axes
from matplotlib.ticker import MaxNLocator

from b_asic.process import Process


# From https://stackoverflow.com/questions/2669059/how-to-sort-alpha-numeric-set-in-python
def _sorted_nicely(to_be_sorted):
    """Sort the given iterable in the way that humans expect."""
    convert = lambda text: int(text) if text.isdigit() else text
    alphanum_key = lambda key: [
        convert(c) for c in re.split('([0-9]+)', str(key))
    ]
    return sorted(to_be_sorted, key=alphanum_key)


def draw_exclusion_graph_coloring(
    exclusion_graph: nx.Graph,
    color_dict: Dict[Process, int],
    ax: Optional[Axes] = None,
    color_list: Optional[
        Union[List[str], List[Tuple[float, float, float]]]
    ] = None,
):
    """
    Use matplotlib.pyplot and networkx to draw a colored exclusion graph from the memory assignment.

    .. code-block:: python

        _, ax = plt.subplots(1, 1)
        collection = ProcessCollection(...)
        exclusion_graph = collection.create_exclusion_graph_from_overlap()
        color_dict = nx.greedy_color(exclusion_graph)
        draw_exclusion_graph_coloring(exclusion_graph, color_dict, ax=ax)
        plt.show()

    Parameters
    ----------
    exclusion_graph : nx.Graph
        A nx.Graph exclusion graph object that is to be drawn.

    color_dict : dictionary
        A color dictionary where keys are Process objects and where values are integers representing colors. These
        dictionaries are automatically generated by :func:`networkx.algorithms.coloring.greedy_color`.
        Every node of *exclusion_graph* must have an entry.

    ax : :class:`matplotlib.axes.Axes`, optional
        A Matplotlib Axes object to draw the exclusion graph onto. It is passed
        unchanged to :func:`networkx.draw_networkx`.

    color_list : list of str or list of (float, float, float), optional
        Palette indexed by the integer colors in *color_dict*, e.g. hex strings
        such as ``'#aa0000'`` or RGB float tuples. If None, a built-in palette
        of 15 colors is used.

    Raises
    ------
    KeyError
        If a node of *exclusion_graph* is missing from *color_dict*.
    IndexError
        If a color index used by a node is outside the palette.
    """
    COLOR_LIST = [
        '#aa0000',
        '#00aa00',
        '#0000ff',
        '#ff00aa',
        '#ffaa00',
        '#00ffaa',
        '#aaff00',
        '#aa00ff',
        '#00aaff',
        '#ff0000',
        '#00ff00',
        '#0000aa',
        '#aaaa00',
        '#aa00aa',
        '#00aaaa',
    ]
    palette = COLOR_LIST if color_list is None else color_list
    # Build the colors in graph-node order so that each entry lines up with
    # the node order used by networkx when drawing.
    node_color_list = [palette[color_dict[node]] for node in exclusion_graph]
    nx.draw_networkx(
        exclusion_graph,
        node_color=node_color_list,
        ax=ax,
        # Fixed seed gives a reproducible layout between calls.
        pos=nx.spring_layout(exclusion_graph, seed=1),
    )


class ProcessCollection:
    """
    Collection of one or more processes.

    Parameters
    ----------
    collection : set of :class:`~b_asic.process.Process` objects
        The Process objects forming this ProcessCollection. The set is stored
        by reference, so :meth:`add_process` also adds to the caller's set.
    schedule_time : int
        Length of the time-axis in the generated graph.
    cyclic : bool, default: False
        If the processes operates cyclically, i.e., if time 0 == time *schedule_time*.
    """

    def __init__(
        self,
        collection: Set[Process],
        schedule_time: int,
        cyclic: bool = False,
    ):
        self._collection = collection
        self._schedule_time = schedule_time
        self._cyclic = cyclic

    def add_process(self, process: Process):
        """
        Add a new process to this process collection.

        Parameters
        ----------
        process : Process
            The process object to be added to the collection
        """
        self._collection.add(process)

    def draw_lifetime_chart(
        self,
        ax: Optional[Axes] = None,
        show_name: bool = True,
    ):
        """
        Use matplotlib.pyplot to generate a process variable lifetime chart from this process collection.

        Each process gets its own row, ordered by :func:`_sorted_nicely`. Its
        lifetime is drawn as a horizontal bar, with start and end times taken
        modulo the schedule time. A lifetime that passes the end of the
        schedule is drawn as two pieces.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`, optional
            Matplotlib Axes object to draw this lifetime chart onto. If not provided (i.e., set to None), a new
            figure and axes are created with :func:`matplotlib.pyplot.subplots`.
        show_name : bool, default: True
            Show name of all processes in the lifetime chart.

        Returns
        -------
            ax: The Matplotlib Axes object the chart was drawn onto.

        Raises
        ------
        KeyError
            If the execution time of any process exceeds the schedule time.
        """

        # Set up the Axes object
        if ax is None:
            _, _ax = plt.subplots()
        else:
            _ax = ax

        # Draw the lifetime chart
        PAD_L, PAD_R = 0.05, 0.05
        # default=0 lets an empty collection produce an empty chart instead
        # of max() raising on an empty sequence.
        max_execution_time = max(
            (process.execution_time for process in self._collection),
            default=0,
        )
        if max_execution_time > self._schedule_time:
            # Schedule time needs to be greater than or equal to the maximum
            # process life time. NOTE(review): ValueError would fit better;
            # KeyError is kept because existing callers may catch it.
            raise KeyError(
                f'Error: Schedule time: {self._schedule_time} < Max execution'
                f' time: {max_execution_time}'
            )
        for i, process in enumerate(_sorted_nicely(self._collection)):
            row = (i + 0.55, 0.9)  # (bottom, height) of this process' row
            bar_start = process.start_time % self._schedule_time
            bar_end = (
                process.start_time + process.execution_time
            ) % self._schedule_time
            if bar_end > bar_start:
                # Lifetime fits inside one schedule period.
                _ax.broken_barh(
                    [(PAD_L + bar_start, bar_end - bar_start - PAD_L - PAD_R)],
                    row,
                )
            elif bar_end != 0:
                # bar_end <= bar_start: lifetime wraps past schedule_time.
                # Draw the tail up to the schedule end, then the head from 0.
                _ax.broken_barh(
                    [
                        (
                            PAD_L + bar_start,
                            self._schedule_time - bar_start - PAD_L,
                        )
                    ],
                    row,
                )
                _ax.broken_barh([(0, bar_end - PAD_R)], row)
            else:
                # Lifetime ends exactly at the schedule boundary.
                _ax.broken_barh(
                    [
                        (
                            PAD_L + bar_start,
                            self._schedule_time - bar_start - PAD_L - PAD_R,
                        )
                    ],
                    row,
                )
            if show_name:
                _ax.annotate(
                    str(process),
                    (bar_start + PAD_L + 0.025, i + 1.00),
                    va="center",
                )
        _ax.grid(True)

        _ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        _ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        _ax.set_xlim(0, self._schedule_time)
        _ax.set_ylim(0.25, len(self._collection) + 0.75)
        return _ax

    def create_exclusion_graph_from_overlap(
        self, add_name: bool = True
    ) -> nx.Graph:
        """
        Generate exclusion graph based on processes overlapping in time.

        Two processes overlap when their integer time intervals
        ``[start_time, start_time + execution_time)`` share at least one time
        step. Absolute times are compared; no modulo *schedule_time* is applied.

        Parameters
        ----------
        add_name : bool, default: True
            Intended to add the name of all processes as a node attribute in
            the exclusion graph. Currently not used by this implementation.

        Returns
        -------
            An nx.Graph exclusion graph where nodes are processes and arcs
            between two processes indicate overlap in time.
        """
        exclusion_graph = nx.Graph()
        exclusion_graph.add_nodes_from(self._collection)
        # The overlap relation is symmetric, so each unordered pair is
        # checked once.
        for process1, process2 in combinations(self._collection, 2):
            # Non-empty intersection of two half-open integer intervals.
            # Zero-length lifetimes never overlap anything.
            overlap_start = max(process1.start_time, process2.start_time)
            overlap_stop = min(
                process1.start_time + process1.execution_time,
                process2.start_time + process2.execution_time,
            )
            if overlap_start < overlap_stop:
                exclusion_graph.add_edge(process1, process2)
        return exclusion_graph

    def split(
        self,
        heuristic: str = "graph_color",
        read_ports: Optional[int] = None,
        write_ports: Optional[int] = None,
        total_ports: Optional[int] = None,
    ) -> Set["ProcessCollection"]:
        """
        Split this process storage based on some heuristic.

        Parameters
        ----------
        heuristic : str, default: "graph_color"
            The heuristic used when splitting this ProcessCollection.
            Valid options are:
                * "graph_color"
                * "..."
        read_ports : int, optional
            The number of read ports used when splitting process collection based on memory variable access.
            Defaults to *total_ports* when only *total_ports* is given.
        write_ports : int, optional
            The number of write ports used when splitting process collection based on memory variable access.
            Defaults to *total_ports* when only *total_ports* is given.
        total_ports : int, optional
            The total number of ports used when splitting process collection based on memory variable access.
            Defaults to *read_ports* + *write_ports*.

        Returns
        -------
        A set of new ProcessCollection objects with the process splitting.

        Raises
        ------
        ValueError
            If neither *total_ports* nor both *read_ports* and *write_ports*
            are given, or if *heuristic* is not recognized.
        """
        if total_ports is None:
            if read_ports is None or write_ports is None:
                raise ValueError(
                    "Either total_ports, or both read_ports and write_ports,"
                    " must be provided"
                )
            total_ports = read_ports + write_ports
        else:
            read_ports = total_ports if read_ports is None else read_ports
            write_ports = total_ports if write_ports is None else write_ports

        if heuristic == "graph_color":
            return self._split_graph_color(
                read_ports, write_ports, total_ports
            )
        raise ValueError("Invalid heuristic provided")

    def _split_graph_color(
        self, read_ports: int, write_ports: int, total_ports: int
    ) -> Set["ProcessCollection"]:
        """
        Split the collection by greedy coloring of a port-conflict graph.

        Two processes conflict when their start times coincide or their stop
        times coincide. With a single total port, a start time coinciding with
        a stop time is also a conflict. Each color class becomes one new
        ProcessCollection.

        Parameters
        ----------
        read_ports : int
            The number of read ports used when splitting process collection based on memory variable access.
            Must be 1.
        write_ports : int
            The number of write ports used when splitting process collection based on memory variable access.
            Must be 1.
        total_ports : int
            The total number of ports used when splitting process collection based on memory variable access.
            Must be 1 or 2.

        Returns
        -------
        A set of ProcessCollection objects sharing this collection's schedule
        time and cyclic flag. An empty collection yields an empty set.

        Raises
        ------
        ValueError
            If the port counts are not supported by this heuristic.
        """
        if read_ports != 1 or write_ports != 1:
            raise ValueError(
                "Spliting with read and write ports not equal to one with the"
                " graph coloring heuristic does not make sense."
            )
        if total_ports not in (1, 2):
            raise ValueError(
                "Total ports should be either 1 (non-concurent reads/writes)"
                " or 2 (concurrent read/writes) for graph coloring heuristic."
            )

        # Create new exclusion graph. Nodes are Processes
        exclusion_graph = nx.Graph()
        exclusion_graph.add_nodes_from(self._collection)

        # Add exclusions (arcs) between processes in the exclusion graph.
        # All conditions are symmetric, so each unordered pair is checked once.
        # NOTE(review): presumably a process is written at its start time and
        # read at its stop time; confirm against the Process definition.
        for node1, node2 in combinations(exclusion_graph, 2):
            node1_stop_time = node1.start_time + node1.execution_time
            node2_stop_time = node2.start_time + node2.execution_time
            conflict = (
                node1.start_time == node2.start_time
                or node1_stop_time == node2_stop_time
            )
            if total_ports == 1:
                # Single-port assignment: start/stop coincidences also clash.
                conflict = (
                    conflict
                    or node1.start_time == node2_stop_time
                    or node1_stop_time == node2.start_time
                )
            if conflict:
                exclusion_graph.add_edge(node1, node2)

        # Perform assignment: processes sharing a color go to the same
        # collection.
        coloring = nx.coloring.greedy_color(exclusion_graph)
        processes_by_color: Dict[int, Set[Process]] = {}
        for process, color in coloring.items():
            processes_by_color.setdefault(color, set()).add(process)
        return {
            ProcessCollection(processes, self._schedule_time, self._cyclic)
            for processes in processes_by_color.values()
        }