import io
import itertools
import re
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union

import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.axes import Axes
from matplotlib.ticker import MaxNLocator

from b_asic._preferences import LATENCY_COLOR
from b_asic.process import Process

# Default bar color for lifetime charts: every channel of LATENCY_COLOR divided
# by 255, giving the 0-1 float tuple form Matplotlib accepts for colors
# (assumes LATENCY_COLOR holds 0-255 channel values -- confirm in b_asic._preferences).
_LATENCY_COLOR = tuple(channel / 255 for channel in LATENCY_COLOR)

#
# Human-intuitive sorting:
# https://stackoverflow.com/questions/2669059/how-to-sort-alpha-numeric-set-in-python
#
# Typing '_T' to help Pyright propagate type-information
#
_T = TypeVar('_T')


def _sorted_nicely(to_be_sorted: Iterable[_T]) -> List[_T]:
    """Sort the given iterable in the way that humans expect."""
    convert = lambda text: int(text) if text.isdigit() else text
    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', str(key))]
    return sorted(to_be_sorted, key=alphanum_key)


def draw_exclusion_graph_coloring(
    exclusion_graph: nx.Graph,
    color_dict: Dict[Process, int],
    ax: Optional[Axes] = None,
    color_list: Optional[Union[List[str], List[Tuple[float, float, float]]]] = None,
):
    """
    Use matplotlib.pyplot and networkx to draw a colored exclusion graph from the memory assignment

    .. code-block:: python

        _, ax = plt.subplots()
        collection = ProcessCollection(...)
        exclusion_graph = collection.create_exclusion_graph_from_execution_time()
        color_dict = nx.greedy_color(exclusion_graph)
        draw_exclusion_graph_coloring(exclusion_graph, color_dict, ax=ax)
        plt.show()

    Parameters
    ----------
    exclusion_graph : nx.Graph
        A nx.Graph exclusion graph object that is to be drawn.

    color_dict : dictionary
        A color dictionary where keys are Process objects and where values are integers representing colors. These
        dictionaries are automatically generated by :func:`networkx.algorithms.coloring.greedy_color`.

    ax : :class:`matplotlib.axes.Axes`, optional
        A Matplotlib Axes object to draw the exclusion graph onto. It is passed
        unchanged to :func:`networkx.draw_networkx`.

    color_list : list of str or list of (float, float, float), optional
        Palette indexed by the integer colors in *color_dict*: a node with color
        ``c`` is drawn with ``color_list[c]``. Entries are Matplotlib color strings
        or RGB tuples. If not provided, a built-in palette of 15 colors is used.

    Raises
    ------
    KeyError
        If a node of *exclusion_graph* has no entry in *color_dict*.
    IndexError
        If a color index used by a node is too large for the palette.
    """
    COLOR_LIST = [
        '#aa0000',
        '#00aa00',
        '#0000ff',
        '#ff00aa',
        '#ffaa00',
        '#00ffaa',
        '#aaff00',
        '#aa00ff',
        '#00aaff',
        '#ff0000',
        '#00ff00',
        '#0000aa',
        '#aaaa00',
        '#aa00aa',
        '#00aaaa',
    ]
    palette = COLOR_LIST if color_list is None else color_list
    # One color per node, in graph-node order: draw_networkx pairs a node_color
    # sequence with the graph's nodes in that order.
    node_color_list = [palette[color_dict[node]] for node in exclusion_graph]
    nx.draw_networkx(
        exclusion_graph,
        node_color=node_color_list,
        ax=ax,
        # Fixed seed so the layout is reproducible between calls.
        pos=nx.spring_layout(exclusion_graph, seed=1),
    )


class ProcessCollection:
    """
    Collection of one or more processes

    Parameters
    ----------
    collection : set of :class:`~b_asic.process.Process` objects
        The Process objects forming this ProcessCollection. The set is stored
        without copying, so :meth:`add_process` also adds to the caller's set.
    schedule_time : int
        Length of the time-axis in the generated graph.
    cyclic : bool, default: False
        If the processes operate cyclically, i.e., if time 0 == time *schedule_time*.
    """

    def __init__(
        self,
        collection: Set[Process],
        schedule_time: int,
        cyclic: bool = False,
    ):
        # The set is aliased, not copied: add_process() mutates it in place.
        self._collection = collection
        self._schedule_time = schedule_time
        self._cyclic = cyclic

    @property
    def collection(self):
        """The set of :class:`~b_asic.process.Process` objects in this collection."""
        return self._collection

    def __len__(self):
        """Return the number of processes in this collection."""
        return len(self._collection)

    def add_process(self, process: Process):
        """
        Add a new process to this process collection.

        Parameters
        ----------
        process : Process
            The process object to be added to the collection
        """
        self._collection.add(process)

    @staticmethod
    def _resolve_ports(
        read_ports: Optional[int],
        write_ports: Optional[int],
        total_ports: Optional[int],
    ) -> Tuple[int, int, int]:
        """
        Fill in unset port counts and return ``(read_ports, write_ports, total_ports)``.

        If *total_ports* is unset it becomes ``read_ports + write_ports``. Otherwise
        an unset *read_ports* or *write_ports* defaults to *total_ports*.

        Raises
        ------
        ValueError
            If *total_ports* is unset and *read_ports* or *write_ports* is unset too.
        """
        if total_ports is None:
            if read_ports is None or write_ports is None:
                raise ValueError(
                    "If total_ports is unset, both read_ports and write_ports"
                    " must be provided."
                )
            return read_ports, write_ports, read_ports + write_ports
        return (
            total_ports if read_ports is None else read_ports,
            total_ports if write_ports is None else write_ports,
            total_ports,
        )

    def plot(
        self,
        ax: Optional[Axes] = None,
        show_name: bool = True,
        bar_color: Union[str, Tuple[float, ...]] = _LATENCY_COLOR,
        marker_color: Union[str, Tuple[float, ...]] = "black",
        marker_read: str = "X",
        marker_write: str = "o",
        show_markers: bool = True,
    ):
        """
        Use matplotlib.pyplot to generate a process variable lifetime chart from this process collection.

        Each process is drawn on its own row as a bar from its start (write) time to
        its end (read) time. Times are reduced modulo the schedule time, and a
        process that crosses the end of the schedule is drawn as two bars.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`, optional
            Matplotlib Axes object to draw this lifetime chart onto. If not provided (i.e., set to None),
            a new figure and Axes are created with :func:`matplotlib.pyplot.subplots`.
        show_name : bool, default: True
            Show name of all processes in the lifetime chart.
        bar_color : color, optional
            Bar color in lifetime chart.
        marker_color : color, default 'black'
            Color for read and write marker.
        marker_read : str, default 'X'
            Marker at read time in the lifetime chart.
        marker_write : str, default 'o'
            Marker at write time in the lifetime chart.
        show_markers : bool, default True
            Show markers at read and write times.

        Returns
        -------
        ax : :class:`matplotlib.axes.Axes`
            The Axes object the lifetime chart was drawn onto.

        Raises
        ------
        KeyError
            If the longest process execution time exceeds the schedule time.
        """

        # Set up the Axes object
        _ax = plt.subplots()[1] if ax is None else ax

        # Lifetime chart left and right padding
        PAD_L, PAD_R = 0.05, 0.05
        # default=0 lets an empty collection produce an empty chart instead of
        # max() raising on an empty sequence.
        max_execution_time = max(
            (process.execution_time for process in self._collection), default=0
        )
        if max_execution_time > self._schedule_time:
            # Schedule time needs to be greater than or equal to the maximum process lifetime
            raise KeyError(
                f'Error: Schedule time: {self._schedule_time} < Max execution'
                f' time: {max_execution_time}'
            )

        # Generate the life-time chart
        for i, process in enumerate(_sorted_nicely(self._collection)):
            bar_start = process.start_time % self._schedule_time
            # Because execution_time <= schedule_time (checked above), the end
            # time wraps past the schedule boundary at most once.
            unwrapped_end = bar_start + process.execution_time
            wraps = unwrapped_end > self._schedule_time
            bar_end = unwrapped_end - self._schedule_time if wraps else unwrapped_end
            if show_markers:
                _ax.scatter(
                    x=bar_start,
                    y=i + 1,
                    marker=marker_write,
                    color=marker_color,
                    zorder=10,
                )
                _ax.scatter(
                    x=bar_end,
                    y=i + 1,
                    marker=marker_read,
                    color=marker_color,
                    zorder=10,
                )
            if not wraps:
                _ax.broken_barh(
                    [(PAD_L + bar_start, bar_end - bar_start - PAD_L - PAD_R)],
                    (i + 0.55, 0.9),
                    color=bar_color,
                )
            else:
                # Crosses the schedule boundary: tail up to schedule_time, then
                # head from time 0.
                _ax.broken_barh(
                    [
                        (
                            PAD_L + bar_start,
                            self._schedule_time - bar_start - PAD_L,
                        )
                    ],
                    (i + 0.55, 0.9),
                    color=bar_color,
                )
                _ax.broken_barh(
                    [(0, bar_end - PAD_R)], (i + 0.55, 0.9), color=bar_color
                )
            if show_name:
                _ax.annotate(
                    str(process),
                    (bar_start + PAD_L + 0.025, i + 1.00),
                    va="center",
                )
        _ax.grid(True)

        _ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        _ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        _ax.set_xlim(0, self._schedule_time)
        _ax.set_ylim(0.25, len(self._collection) + 0.75)
        return _ax

    def create_exclusion_graph_from_ports(
        self,
        read_ports: Optional[int] = None,
        write_ports: Optional[int] = None,
        total_ports: Optional[int] = None,
    ) -> nx.Graph:
        """
        Create an exclusion graph from a ProcessCollection based on a number of read/write ports

        Two processes are connected when they share a start time or a stop time
        (start + execution time). With a single total port, they are also
        connected when one starts at the time the other stops.

        Parameters
        ----------
        read_ports : int
            The number of read ports used when splitting process collection based on memory variable access.
        write_ports : int
            The number of write ports used when splitting process collection based on memory variable access.
        total_ports : int
            The total number of ports used when splitting process collection based on memory variable access.

        Returns
        -------
        nx.Graph
            Exclusion graph whose nodes are the processes of this collection.

        Raises
        ------
        ValueError
            If the port counts cannot be resolved, if read/write ports are not
            both 1, or if the total number of ports is not 1 or 2.
        """
        read_ports, write_ports, total_ports = self._resolve_ports(
            read_ports, write_ports, total_ports
        )

        # Guard for proper read/write port settings
        if read_ports != 1 or write_ports != 1:
            raise ValueError(
                "Splitting with read and write ports not equal to one with the"
                " graph coloring heuristic does not make sense."
            )
        if total_ports not in (1, 2):
            raise ValueError(
                "Total ports should be either 1 (non-concurrent reads/writes)"
                " or 2 (concurrent read/writes) for graph coloring heuristic."
            )

        # Create new exclusion graph. Nodes are Processes
        exclusion_graph = nx.Graph()
        exclusion_graph.add_nodes_from(self._collection)
        # NOTE(review): times are compared as-is, without reducing them modulo
        # schedule_time, even for cyclic collections -- confirm this is intended.
        # The conflict tests are symmetric, so each unordered pair is checked once.
        for node1, node2 in itertools.combinations(exclusion_graph, 2):
            node1_stop_time = node1.start_time + node1.execution_time
            node2_stop_time = node2.start_time + node2.execution_time
            # Both assignments: no two processes share a start or a stop time.
            conflict = (
                node1.start_time == node2.start_time
                or node1_stop_time == node2_stop_time
            )
            if total_ports == 1:
                # Single-port assignment: a start may not coincide with a stop either.
                conflict = (
                    conflict
                    or node1.start_time == node2_stop_time
                    or node1_stop_time == node2.start_time
                )
            if conflict:
                exclusion_graph.add_edge(node1, node2)
        return exclusion_graph

    def create_exclusion_graph_from_execution_time(self) -> nx.Graph:
        """
        Generate exclusion graph based on processes overlapping in time

        A process occupies the integer time steps
        ``start_time, ..., start_time + execution_time - 1``. Two processes are
        connected when they occupy at least one common time step.

        Returns
        -------
            An nx.Graph exclusion graph where nodes are processes and arcs
            between two processes indicated overlap in time
        """
        exclusion_graph = nx.Graph()
        exclusion_graph.add_nodes_from(self._collection)
        for process1, process2 in itertools.combinations(self._collection, 2):
            # Half-open intervals [start, stop) share a step iff the later start
            # lies before the earlier stop (false if either interval is empty).
            latest_start = max(process1.start_time, process2.start_time)
            earliest_stop = min(
                process1.start_time + process1.execution_time,
                process2.start_time + process2.execution_time,
            )
            if latest_start < earliest_stop:
                exclusion_graph.add_edge(process1, process2)
        return exclusion_graph

    def split_execution_time(
        self, heuristic: str = "graph_color", coloring_strategy: str = "DSATUR"
    ) -> Set["ProcessCollection"]:
        """
        Split a ProcessCollection based on overlapping execution time.

        Parameters
        ----------
        heuristic : str, default: 'graph_color'
            The heuristic used when splitting based on execution times.
            One of: 'graph_color', 'left_edge'.
        coloring_strategy: str, default: 'DSATUR'
            Node ordering strategy passed to nx.coloring.greedy_color() if the heuristic is set to 'graph_color'. This
            parameter is only considered if heuristic is set to graph_color.
            One of
               * `'largest_first'`
               * `'random_sequential'`
               * `'smallest_last'`
               * `'independent_set'`
               * `'connected_sequential_bfs'`
               * `'connected_sequential_dfs'`
               * `'connected_sequential'` (alias for the previous strategy)
               * `'saturation_largest_first'`
               * `'DSATUR'` (alias for the saturation_largest_first strategy)

        Returns
        -------
        A set of new ProcessCollection objects with the process splitting.

        Raises
        ------
        NotImplementedError
            If *heuristic* is 'left_edge'.
        ValueError
            If *heuristic* is not recognized.
        """
        if heuristic == "graph_color":
            exclusion_graph = self.create_exclusion_graph_from_execution_time()
            coloring = nx.coloring.greedy_color(
                exclusion_graph, strategy=coloring_strategy
            )
            return self._split_from_graph_coloring(coloring)
        elif heuristic == "left_edge":
            raise NotImplementedError()
        else:
            raise ValueError(f"Invalid heuristic '{heuristic}'")

    def split_ports(
        self,
        heuristic: str = "graph_color",
        read_ports: Optional[int] = None,
        write_ports: Optional[int] = None,
        total_ports: Optional[int] = None,
        coloring_strategy: str = "DSATUR",
    ) -> Set["ProcessCollection"]:
        """
        Split this process storage based on some heuristic.

        Parameters
        ----------
        heuristic : str, default: "graph_color"
            The heuristic used when splitting this ProcessCollection.
            Currently the only valid option is "graph_color".
        read_ports : int, optional
            The number of read ports used when splitting process collection based on memory variable access.
        write_ports : int, optional
            The number of write ports used when splitting process collection based on memory variable access.
        total_ports : int, optional
            The total number of ports used when splitting process collection based on memory variable access.
        coloring_strategy : str, default: 'DSATUR'
            Node ordering strategy passed to nx.coloring.greedy_color(). See
            :meth:`split_execution_time` for the accepted values.

        Returns
        -------
        A set of new ProcessCollection objects with the process splitting.

        Raises
        ------
        ValueError
            If the port counts are invalid or the heuristic is not recognized.
        """
        read_ports, write_ports, total_ports = self._resolve_ports(
            read_ports, write_ports, total_ports
        )
        if heuristic == "graph_color":
            return self._split_ports_graph_color(
                read_ports, write_ports, total_ports, coloring_strategy
            )
        else:
            raise ValueError("Invalid heuristic provided.")

    def _split_ports_graph_color(
        self,
        read_ports: int,
        write_ports: int,
        total_ports: int,
        coloring_strategy: str = "DSATUR",
    ) -> Set["ProcessCollection"]:
        """
        Split this collection by greedily coloring its port-based exclusion graph.

        Parameters
        ----------
        read_ports : int
            The number of read ports used when splitting process collection based on memory variable access.
        write_ports : int
            The number of write ports used when splitting process collection based on memory variable access.
        total_ports : int
            The total number of ports used when splitting process collection based on memory variable access.
        coloring_strategy: str, default: 'DSATUR'
            Node ordering strategy passed to nx.coloring.greedy_color()
            One of
               * `'largest_first'`
               * `'random_sequential'`
               * `'smallest_last'`
               * `'independent_set'`
               * `'connected_sequential_bfs'`
               * `'connected_sequential_dfs'`
               * `'connected_sequential'` (alias for the previous strategy)
               * `'saturation_largest_first'`
               * `'DSATUR'` (alias for the saturation_largest_first strategy)

        Returns
        -------
        A set of new ProcessCollection objects, one per color.
        """
        # Create new exclusion graph. Nodes are Processes
        exclusion_graph = self.create_exclusion_graph_from_ports(
            read_ports, write_ports, total_ports
        )

        # Perform assignment from coloring and return result
        coloring = nx.coloring.greedy_color(exclusion_graph, strategy=coloring_strategy)
        return self._split_from_graph_coloring(coloring)

    def _split_from_graph_coloring(
        self,
        coloring: Dict[Process, int],
    ) -> Set["ProcessCollection"]:
        """
        Split :class:`Process` objects into a set of :class:`ProcessCollection` objects based on a provided graph coloring.
        Resulting :class:`ProcessCollection` will have the same schedule time and cyclic property as self.

        Parameters
        ----------
        coloring : Dict[Process, int]
            Process->int (color) mappings

        Returns
        -------
        A set of new ProcessCollections. Empty if *coloring* is empty.
        """
        # default=-1 makes an empty coloring (empty collection) yield no groups
        # instead of max() raising on an empty sequence.
        num_colors = max(coloring.values(), default=-1) + 1
        process_collection_set_list = [set() for _ in range(num_colors)]
        for process, color in coloring.items():
            process_collection_set_list[color].add(process)
        return {
            ProcessCollection(process_collection_set, self._schedule_time, self._cyclic)
            for process_collection_set in process_collection_set_list
        }

    def _repr_svg_(self) -> str:
        """
        Generate an SVG of the resource collection. This is automatically displayed in e.g.
        Jupyter Qt console.

        The lifetime chart is drawn with :meth:`plot` (without read/write markers)
        on a temporary figure, which is closed afterwards.
        """
        fig, ax = plt.subplots()
        try:
            self.plot(ax, show_markers=False)
            f = io.StringIO()
            fig.savefig(f, format="svg")
        finally:
            # Close the pyplot-managed figure so repeated rich displays do not
            # accumulate open figures.
            plt.close(fig)
        return f.getvalue()

    def __repr__(self):
        return (
            f"ProcessCollection({self._collection}, {self._schedule_time},"
            f" {self._cyclic})"
        )