Skip to content
Snippets Groups Projects
resources.py 43.1 KiB
Newer Older
import io
import re
from functools import reduce
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union
Mikael Henriksson's avatar
Mikael Henriksson committed

import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.axes import Axes
from matplotlib.ticker import MaxNLocator

from b_asic._preferences import LATENCY_COLOR
from b_asic.process import MemoryVariable, PlainMemoryVariable, Process
Mikael Henriksson's avatar
Mikael Henriksson committed

# Default latency bar color as an RGB tuple of floats. LATENCY_COLOR is divided
# by 255 here, i.e. it is assumed to hold 0-255 channel values.
_LATENCY_COLOR = tuple(channel / 255 for channel in LATENCY_COLOR)

#
# Human-intuitive ("natural") sorting, see:
# https://stackoverflow.com/questions/2669059/how-to-sort-alpha-numeric-set-in-python
#
# The type variable '_T' lets Pyright propagate the element type of the input
# iterable of _sorted_nicely to the returned list.
#
_T = TypeVar("_T")
Mikael Henriksson's avatar
Mikael Henriksson committed


def _sorted_nicely(to_be_sorted: Iterable[_T]) -> List[_T]:
    """
    Sort the given iterable in the way that humans expect ("natural" order).

    Each element is converted with :func:`str` and split into alternating runs of
    non-digit text and ASCII digit runs. Digit runs are compared numerically, so
    ``'p2'`` sorts before ``'p10'``.

    Parameters
    ----------
    to_be_sorted : iterable
        Elements to sort. They are compared through their string representation.

    Returns
    -------
    list
        A new list holding the elements in natural order.
    """

    def alphanum_key(key) -> List[Union[int, str]]:
        # re.split with a capturing group places the captured digit runs at the
        # odd indices, so exactly those are converted to int. Testing
        # str.isdigit() instead also matches non-ASCII digits such as '²',
        # which int() rejects (and would break the str/int alternation).
        parts = re.split('([0-9]+)', str(key))
        return [int(part) if index % 2 else part for index, part in enumerate(parts)]

    return sorted(to_be_sorted, key=alphanum_key)


def _sanitize_port_option(
    read_ports: Optional[int] = None,
    write_ports: Optional[int] = None,
    total_ports: Optional[int] = None,
) -> Tuple[int, int, int]:
    """
    General port sanitization function used to test if a port specification makes sense.
    Raises ValueError if the port specification is in-proper.

    Parameters
    ----------
    read_ports : int, optional
        The number of read ports.
    write_ports : int, optional
        The number of write ports.
    total_ports : int, optional
        The total number of ports

    Returns
    -------
    Returns a triple int tuple (read_ports, write_ports, total_ports) equal to the input, or sanitized if one of the input equals None.
    If total_ports is set to None at the input, it is set to read_ports+write_ports at the output.
    If read_ports or write_ports is set to None at the input, it is set to total_ports at the output.

    """
    if total_ports is None:
        if read_ports is None or write_ports is None:
            raise ValueError(
                "If total_ports is unset, both read_ports and write_ports"
                " must be provided."
            )
        else:
            total_ports = read_ports + write_ports
    else:
        read_ports = total_ports if read_ports is None else read_ports
        write_ports = total_ports if write_ports is None else write_ports
    if total_ports < read_ports:
        raise ValueError(
            f'Total ports ({total_ports}) less then read ports ({read_ports})'
        )
    if total_ports < write_ports:
        raise ValueError(
            f'Total ports ({total_ports}) less then write ports ({write_ports})'
        )
    return (read_ports, write_ports, total_ports)


def draw_exclusion_graph_coloring(
    exclusion_graph: nx.Graph,
    color_dict: Dict[Process, int],
    ax: Optional[Axes] = None,
    color_list: Optional[Union[List[str], List[Tuple[float, float, float]]]] = None,
) -> None:
    """
    Draw a colored exclusion graph from the memory assignment.

    .. code-block:: python

        _, ax = plt.subplots()
        collection = ProcessCollection(...)
        exclusion_graph = collection.create_exclusion_graph_from_execution_time()
        color_dict = nx.greedy_color(exclusion_graph)
        draw_exclusion_graph_coloring(exclusion_graph, color_dict, ax=ax)
        plt.show()

    Parameters
    ----------
    exclusion_graph : nx.Graph
        A nx.Graph exclusion graph object that is to be drawn.
    color_dict : dict
        A dict where keys are :class:`~b_asic.process.Process` objects and values are
        integers representing colors. These dictionaries are automatically generated by
        :func:`networkx.algorithms.coloring.greedy_color`. Every node of
        *exclusion_graph* must be a key.
    ax : :class:`matplotlib.axes.Axes`, optional
        A Matplotlib :class:`~matplotlib.axes.Axes` object to draw the exclusion graph.
        Passed on unchanged to :func:`networkx.draw_networkx`.
    color_list : sequence of color, optional
        A list of colors in Matplotlib format, indexed by the integers in
        *color_dict*. If None, a built-in palette of 15 colors is used.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If a node of *exclusion_graph* is missing from *color_dict*.
    IndexError
        If a color index in *color_dict* is not smaller than the number of
        available colors.
    """
    COLOR_LIST = [
        '#aa0000',
        '#00aa00',
        '#0000ff',
        '#ff00aa',
        '#ffaa00',
        '#00ffaa',
        '#aaff00',
        '#aa00ff',
        '#00aaff',
        '#ff0000',
        '#00ff00',
        '#0000aa',
        '#aaaa00',
        '#aa00aa',
        '#00aaaa',
    ]
    palette = COLOR_LIST if color_list is None else color_list
    # One color per node, in the graph's own node order, as draw_networkx expects
    node_color_list = [palette[color_dict[node]] for node in exclusion_graph]
    nx.draw_networkx(
        exclusion_graph,
        node_color=node_color_list,
        ax=ax,
        # Fixed seed so repeated drawings of the same graph use the same layout
        pos=nx.spring_layout(exclusion_graph, seed=1),
    )


class _ForwardBackwardEntry:
    def __init__(
        self,
        inputs: Optional[List[Process]] = None,
        outputs: Optional[List[Process]] = None,
        regs: Optional[List[Optional[Process]]] = None,
        back_edge_to: Optional[Dict[int, int]] = None,
        back_edge_from: Optional[Dict[int, int]] = None,
        outputs_from: Optional[int] = None,
    ):
        """
        Single entry (one time step) of a _ForwardBackwardTable.

        Aggregates the processes input and output at this time step together with
        the contents of the register chain.

        Parameters
        ----------
        inputs : list of Process, optional
            Processes entering at this time step. Defaults to a new empty list.
        outputs : list of Process, optional
            Processes output at this time step. Defaults to a new empty list.
        regs : list of Process or None, optional
            Register contents; ``None`` marks an empty register. Defaults to a new
            empty list.
        back_edge_to : dict of int to int, optional
            Back edges from a register of this entry (key) to a register of the
            next entry (value). Defaults to a new empty dict.
        back_edge_from : dict of int to int, optional
            Back edges into a register of this entry (key) from a register of the
            previous entry (value). Defaults to a new empty dict.
        outputs_from : int, optional
            Index of the register the outputs of this entry are taken from. The
            forward-backward table uses -1 for processes with zero execution time.
        """
        # Fresh containers per instance; never share a mutable default
        if inputs is None:
            inputs = []
        if outputs is None:
            outputs = []
        if regs is None:
            regs = []
        if back_edge_to is None:
            back_edge_to = {}
        if back_edge_from is None:
            back_edge_from = {}

        self.inputs: List[Process] = inputs
        self.outputs: List[Process] = outputs
        self.regs: List[Optional[Process]] = regs
        self.back_edge_to: Dict[int, int] = back_edge_to
        self.back_edge_from: Dict[int, int] = back_edge_from
        self.outputs_from = outputs_from


class _ForwardBackwardTable:
    def __init__(self, collection: 'ProcessCollection'):
        """
        Forward-Backward allocation table for ProcessCollections.

        This structure implements the forward-backward register allocation
        algorithm, which is used to generate hardware from MemoryVariables in a
        ProcessCollection. The table has one entry per time slot of the schedule.

        Parameters
        ----------
        collection : ProcessCollection
            ProcessCollection to apply forward-backward allocation on

        Raises
        ------
        ValueError
            If a variable can neither be forward nor backward allocated.
        """
        self._collection = collection
        schedule_time = collection.schedule_time

        # Count, for every time slot, how many variables are alive in it
        self._live_variables: List[int] = [0] * schedule_time
        for mv in self._collection:
            stop_time = mv.start_time + mv.execution_time
            for alive_time in range(mv.start_time, stop_time):
                self._live_variables[alive_time % schedule_time] += 1

        # First, create an empty forward-backward table with the right dimensions:
        # one entry per time slot, each with as many registers as the maximum
        # number of simultaneously live variables
        num_regs = max(self._live_variables)
        self.table: List[_ForwardBackwardEntry] = []
        for _ in range(schedule_time):
            entry = _ForwardBackwardEntry()
            # https://github.com/microsoft/pyright/issues/1073
            for _ in range(num_regs):
                entry.regs.append(None)
            self.table.append(entry)

        # Insert all processes (one per time-slot) to the table input
        # TODO: "Input each variable at the time step corresponding to the beginning of its lifetime. If multiple
        #        variables are input in a given cycle, theses are allocated to multple registers such that the variable
        #        with the longest lifetime is allocated to the inital register and the other variables are allocated to
        #        consecutive registers in decreasing order of lifetime." -- K. Parhi
        for mv in collection:
            # Fold the start time into the schedule period, consistent with the
            # live-variable count above and the end-time test in forward allocation
            start_row = mv.start_time % schedule_time
            self.table[start_row].inputs.append(mv)
            if mv.execution_time:
                self.table[(start_row + 1) % schedule_time].regs[0] = mv
            else:
                # Zero-lifetime variables are output directly at their input time
                self.table[start_row].outputs.append(mv)
                self.table[start_row].outputs_from = -1

        # Forward-backward allocation: alternate until every variable is output
        forward = True
        while not self._forward_backward_is_complete():
            if forward:
                self._do_forward_allocation()
            else:
                self._do_single_backward_allocation()
            forward = not forward

    def _forward_backward_is_complete(self) -> bool:
        """Return True when every process of the collection appears as an output."""
        s = {proc for e in self.table for proc in e.outputs}
        return len(self._collection._collection - s) == 0

    def _do_forward_allocation(self):
        """
        Forward all Processes as far as possible in the register chain. Processes are forwarded until they reach their
        end time (at which they are added to the output list), or until they reach the end of the register chain.
        """
        rows = len(self.table)
        cols = len(self.table[0].regs)
        # Note that two passes of the forward allocation need to be done, since variables may loop around the schedule
        # cycle boundary.
        for _ in range(2):
            for time, entry in enumerate(self.table):
                for reg_idx, reg in enumerate(entry.regs):
                    if reg is not None:
                        reg_end_time = (reg.start_time + reg.execution_time) % rows
                        if reg_end_time == time:
                            if reg not in self.table[time].outputs:
                                self.table[time].outputs.append(reg)
                                self.table[time].outputs_from = reg_idx
                        elif reg_idx != cols - 1:
                            next_row = (time + 1) % rows
                            next_col = reg_idx + 1
                            if self.table[next_row].regs[next_col] not in (None, reg):
                                cell = self.table[next_row].regs[next_col]
                                raise ValueError(
                                    f'Can\'t forward allocate {reg} in row={time},'
                                    f' col={reg_idx} to next_row={next_row},'
                                    f' next_col={next_col} (cell contains: {cell})'
                                )
                            else:
                                self.table[next_row].regs[next_col] = reg

    def _do_single_backward_allocation(self):
        """
        Perform backward allocation of a single Process in the allocation table.

        Moves one not-yet-output variable from the last register of some entry into
        a register of the following entry and records the back edge on both entries.

        Raises
        ------
        ValueError
            If no variable can be backward allocated.
        """
        rows = len(self.table)
        cols = len(self.table[0].regs)
        outputs = {out for e in self.table for out in e.outputs}
        #
        # Pass #1: Find any (one) non-dead variable from the last register and try to backward allocate it to a
        # previous register where it is not blocking an open path. This heuristic helps minimize forward allocation
        # moves later.
        #
        for time, entry in enumerate(self.table):
            reg = entry.regs[-1]
            if reg is not None and reg not in outputs:
                next_entry = self.table[(time + 1) % rows]
                for nreg_idx, nreg in enumerate(next_entry.regs):
                    if nreg is None and (
                        nreg_idx == 0 or entry.regs[nreg_idx - 1] is not None
                    ):
                        next_entry.regs[nreg_idx] = reg
                        entry.back_edge_to[cols - 1] = nreg_idx
                        next_entry.back_edge_from[nreg_idx] = cols - 1
                        return
        #
        # Pass #2: Backward allocate the first non-dead variable from the last registers to an empty register.
        #
        for time, entry in enumerate(self.table):
            reg = entry.regs[-1]
            if reg is not None and reg not in outputs:
                next_entry = self.table[(time + 1) % rows]
                for nreg_idx, nreg in enumerate(next_entry.regs):
                    if nreg is None:
                        next_entry.regs[nreg_idx] = reg
                        entry.back_edge_to[cols - 1] = nreg_idx
                        next_entry.back_edge_from[nreg_idx] = cols - 1
                        return

        # All passes failed, raise exception...
        raise ValueError(
            "Can't backward allocate any variable. This should not happen."
        )

    def __getitem__(self, key):
        return self.table[key]

    def __iter__(self):
        yield from self.table

    def __len__(self):
        return len(self.table)

    def __str__(self):
        """Render the table as text; back-edge registers are ANSI-highlighted."""

        # Text width of input and output column
        def lst_w(proc_lst):
            return reduce(lambda n, p: n + len(str(p)) + 1, proc_lst, 0)

        input_col_w = max(5, max(lst_w(pl.inputs) for pl in self.table) + 1)
        output_col_w = max(5, max(lst_w(pl.outputs) for pl in self.table) + 1)

        # Text width of register columns
        reg_col_w = 0
        for entry in self.table:
            for reg in entry.regs:
                reg_col_w = max(len(str(reg)), reg_col_w)
        reg_col_w = max(4, reg_col_w + 2)

        num_regs = max(self._live_variables)

        # Header row of the string
        res = f' T |{"In":^{input_col_w}}|'
        for i in range(num_regs):
            reg = f'R{i}'
            res += f'{reg:^{reg_col_w}}|'
        res += f'{"Out":^{output_col_w}}|'
        res += '\n'
        res += (
            6 + input_col_w + (reg_col_w + 1) * num_regs + output_col_w
        ) * '-' + '\n'

        # Register cells that start (green) or end (brown) a back edge are highlighted
        GREEN_BACKGROUND_ANSI = "\u001b[42m"
        BROWN_BACKGROUND_ANSI = "\u001b[43m"
        RESET_BACKGROUND_ANSI = "\033[0m"

        for time, entry in enumerate(self.table):
            # Time
            res += f'{time:^3}| '

            # Input column
            inputs_str = ','.join(proc.name for proc in entry.inputs)
            res += f'{inputs_str:^{input_col_w-1}}|'

            # Register columns
            for reg_idx, reg in enumerate(entry.regs):
                if reg is None:
                    res += " " * reg_col_w + "|"
                else:
                    if reg_idx in entry.back_edge_to:
                        res += f'{GREEN_BACKGROUND_ANSI}'
                        res += f'{reg.name:^{reg_col_w}}'
                        res += f'{RESET_BACKGROUND_ANSI}|'
                    elif reg_idx in entry.back_edge_from:
                        res += f'{BROWN_BACKGROUND_ANSI}'
                        res += f'{reg.name:^{reg_col_w}}'
                        res += f'{RESET_BACKGROUND_ANSI}|'
                    else:
                        res += f'{reg.name:^{reg_col_w}}' + "|"

            # Output column (with the source register index, if any)
            outputs_str = ','.join(proc.name for proc in entry.outputs)
            if entry.outputs_from is not None:
                outputs_str += f"({entry.outputs_from})"
            res += f'{outputs_str:^{output_col_w}}|'

            res += '\n'
        return res


Mikael Henriksson's avatar
Mikael Henriksson committed
class ProcessCollection:
    """
    Collection of one or more processes

    Parameters
    ----------
    collection : set of :class:`~b_asic.process.Process` objects
        The Process objects forming this ProcessCollection.
    schedule_time : int
        Length of the time-axis in the generated graph.
    cyclic : bool, default: False
        If the processes operates cyclically, i.e., if time 0 == time *schedule_time*.
    """

    def __init__(
        self,
        collection: Set[Process],
        schedule_time: int,
        cyclic: bool = False,
    ):
        # The arguments are stored as given: the set is not copied, and nothing
        # is validated here.
        self._cyclic = cyclic
        self._schedule_time = schedule_time
        self._collection = collection

    def collection(self) -> Set[Process]:
        """
        Return the set of processes in this collection.

        The stored set itself is returned (not a copy), so mutating it mutates
        this collection.
        """
        # NOTE(review): unlike `schedule_time`, this is a plain method rather
        # than a property, so callers must call it.
        processes = self._collection
        return processes

    @property
    def schedule_time(self) -> int:
        """The schedule time (length of the time axis) of this collection."""
        period = self._schedule_time
        return period

    def __len__(self):
        """Return the number of processes in this collection."""
        process_count = len(self._collection)
        return process_count

    def add_process(self, process: Process):
        """
        Add a process to this process collection.

        Parameters
        ----------
        process : Process
            The process object to add. The collection is stored as a set, so
            adding a process that is already present has no effect.
        """
        processes = self._collection
        processes.add(process)

Oscar Gustafsson's avatar
Oscar Gustafsson committed
    def plot(
        self,
        ax: Optional[Axes] = None,
        show_name: bool = True,
        bar_color: Union[str, Tuple[float, ...]] = _LATENCY_COLOR,
        marker_color: Union[str, Tuple[float, ...]] = "black",
        marker_read: str = "X",
        marker_write: str = "o",
        show_markers: bool = True,
        row: Optional[int] = None,
    ):
        """
        Plot a process variable lifetime chart.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`, optional
            Matplotlib :class:`~matplotlib.axes.Axes` object to draw this lifetime chart
            onto. If not provided (i.e., set to None), a new Axes object is created
            and returned.
        show_name : bool, default: True
            Show name of all processes in the lifetime chart.
        bar_color : color, optional
            Bar color in lifetime chart.
        marker_color : color, default: 'black'
            Color for read and write marker.
        marker_read : str, default: 'X'
            Marker at read time (end of the lifetime) in the lifetime chart.
        marker_write : str, default: 'o'
            Marker at write time (start of the lifetime) in the lifetime chart.
        show_markers : bool, default: True
            Show markers at read and write times.
        row : int, optional
            Render all processes in this collection on a specified row in the matplotlib axes object.
            Defaults to None, which renders all processes on separate rows. This option is useful when
            drawing cell assignments.

        Returns
        -------
        ax : The Matplotlib Axes object the chart was drawn onto

        Raises
        ------
        KeyError
            If any process has an execution time longer than the schedule time.
        """

        # Set up the Axes object
        if ax is None:
            _, _ax = plt.subplots()
        else:
            _ax = ax

        PAD_L, PAD_R = 0.05, 0.05
        # default=0: an empty collection yields an empty chart instead of max()
        # raising on an empty sequence
        max_execution_time = max(
            (process.execution_time for process in self._collection), default=0
        )
        if max_execution_time > self._schedule_time:
            # Schedule time needs to be greater than or equal to the maximum process
            # lifetime
            raise KeyError(
                f'Error: Schedule time: {self._schedule_time} < Max execution'
                f' time: {max_execution_time}'
            )
        for i, process in enumerate(_sorted_nicely(self._collection)):
            bar_row = i if row is None else row
            bar_start = process.start_time % self._schedule_time
            bar_end = process.start_time + process.execution_time
            # An end time exactly at the schedule time is kept (not wrapped to 0)
            bar_end = (
                bar_end
                if bar_end == self._schedule_time
                else bar_end % self._schedule_time
            )
            if show_markers:
                _ax.scatter(  # type: ignore
                    x=bar_start,
                    y=bar_row + 1,
                    marker=marker_write,
                    color=marker_color,
                    zorder=10,
                )
                _ax.scatter(  # type: ignore
                    x=bar_end,
                    y=bar_row + 1,
                    marker=marker_read,
                    color=marker_color,
                    zorder=10,
                )
            if bar_end >= bar_start:
                _ax.broken_barh(  # type: ignore
                    [(PAD_L + bar_start, bar_end - bar_start - PAD_L - PAD_R)],
                    (bar_row + 0.55, 0.9),
                    color=bar_color,
                )
            else:  # bar_end < bar_start
                # The lifetime wraps around the schedule boundary: draw one bar
                # up to the schedule time and one from time 0
                _ax.broken_barh(  # type: ignore
                    [
                        (
                            PAD_L + bar_start,
                            self._schedule_time - bar_start - PAD_L,
                        )
                    ],
                    (bar_row + 0.55, 0.9),
                    color=bar_color,
                )
                _ax.broken_barh(  # type: ignore
                    [(0, bar_end - PAD_R)], (bar_row + 0.55, 0.9), color=bar_color
                )
            if show_name:
                _ax.annotate(  # type: ignore
                    str(process),
                    (bar_start + PAD_L + 0.025, bar_row + 1.00),
                    va="center",
                )
        _ax.grid(True)  # type: ignore

        _ax.xaxis.set_major_locator(MaxNLocator(integer=True))  # type: ignore
        _ax.yaxis.set_major_locator(MaxNLocator(integer=True))  # type: ignore
        _ax.set_xlim(0, self._schedule_time)  # type: ignore
        if row is None:
            _ax.set_ylim(0.25, len(self._collection) + 0.75)  # type: ignore
        return _ax

    def create_exclusion_graph_from_ports(
        self,
        read_ports: Optional[int] = None,
        write_ports: Optional[int] = None,
        total_ports: Optional[int] = None,
    ) -> nx.Graph:
        """
        Create an exclusion graph based on a number of read/write ports.

        Each process is accessed at its start time and at its stop time
        (``start_time + execution_time``). Two processes are connected when these
        accesses collide on the available ports. Access times are compared
        modulo the schedule time.

        Parameters
        ----------
        read_ports : int, optional
            The number of read ports used when splitting process collection based on
            memory variable access.
        write_ports : int, optional
            The number of write ports used when splitting process collection based on
            memory variable access.
        total_ports : int, optional
            The total number of ports used when splitting process collection based on
            memory variable access.

        Returns
        -------
        An nx.Graph exclusion graph where nodes are processes and edges connect
        processes whose accesses collide on the ports.

        Raises
        ------
        ValueError
            If the port specification is invalid, if read or write ports are not
            exactly one, or if total ports is not 1 or 2.
        """
        read_ports, write_ports, total_ports = _sanitize_port_option(
            read_ports, write_ports, total_ports
        )

        # Guard for proper read/write port settings
        if read_ports != 1 or write_ports != 1:
            raise ValueError(
                "Splitting with read and write ports not equal to one with the"
                " graph coloring heuristic does not make sense."
            )
        if total_ports not in (1, 2):
            raise ValueError(
                "Total ports should be either 1 (non-concurrent reads/writes)"
                " or 2 (concurrent read/writes) for graph coloring heuristic."
            )

        # Create new exclusion graph. Nodes are Processes
        exclusion_graph = nx.Graph()
        exclusion_graph.add_nodes_from(self._collection)
        nodes = list(exclusion_graph)

        # (start slot, stop slot) of every node, folded into the schedule period
        access_times = [
            (
                node.start_time % self._schedule_time,
                (node.start_time + node.execution_time) % self._schedule_time,
            )
            for node in nodes
        ]

        # The conflict relation is symmetric, so each unordered pair is checked once
        for i, node1 in enumerate(nodes):
            start1, stop1 = access_times[i]
            for j in range(i + 1, len(nodes)):
                start2, stop2 = access_times[j]
                if start1 == start2 or stop1 == stop2:
                    # Two accesses of the same kind in the same time slot
                    conflict = True
                elif total_ports == 1:
                    # Single-port assignment: a start and a stop access collide too
                    conflict = start1 == stop2 or stop1 == start2
                else:
                    conflict = False
                if conflict:
                    exclusion_graph.add_edge(node1, nodes[j])
        return exclusion_graph

    def create_exclusion_graph_from_execution_time(self) -> nx.Graph:
        """
        Generate exclusion graph based on processes overlapping in time.

        A process occupies time slot ``t % schedule_time`` for every ``t`` with
        ``start_time <= t < start_time + execution_time``.

        Returns
        -------
        An nx.Graph exclusion graph where nodes are processes and edges
        between two processes indicate overlap in time
        """
        exclusion_graph = nx.Graph()
        exclusion_graph.add_nodes_from(self._collection)
        processes = list(self._collection)

        # Occupied time slots of each process, computed once instead of per pair
        alive_slots = [
            {
                t % self._schedule_time
                for t in range(
                    process.start_time, process.start_time + process.execution_time
                )
            }
            for process in processes
        ]

        # Overlap is symmetric, so each unordered pair is checked once
        for i, process1 in enumerate(processes):
            for j in range(i + 1, len(processes)):
                if not alive_slots[i].isdisjoint(alive_slots[j]):
                    exclusion_graph.add_edge(process1, processes[j])
        return exclusion_graph

    def split_execution_time(
        self,
        heuristic: str = "graph_color",
        coloring_strategy: str = "saturation_largest_first",
    ) -> Set["ProcessCollection"]:
        """
        Split a ProcessCollection based on overlapping execution time.

        Parameters
        ----------
        heuristic : {'graph_color', 'left_edge'}, default: 'graph_color'
            The heuristic used when splitting based on execution times.
        coloring_strategy : str, default: 'saturation_largest_first'
            Node ordering strategy passed to :func:`networkx.coloring.greedy_color`.
            This parameter is only considered if *heuristic* is set to 'graph_color'.

            * 'largest_first'
            * 'random_sequential'
            * 'smallest_last'
            * 'independent_set'
            * 'connected_sequential_bfs'
            * 'connected_sequential_dfs' or 'connected_sequential'
            * 'saturation_largest_first' or 'DSATUR'

        Returns
        -------
        A set of new ProcessCollection objects with the process splitting.

        Raises
        ------
        NotImplementedError
            If *heuristic* is 'left_edge'.
        ValueError
            If *heuristic* is not recognized.
        """
        if heuristic == "graph_color":
            exclusion_graph = self.create_exclusion_graph_from_execution_time()
            coloring = nx.coloring.greedy_color(
                exclusion_graph, strategy=coloring_strategy
            )
            return self._split_from_graph_coloring(coloring)
        elif heuristic == "left_edge":
            raise NotImplementedError()
        else:
            raise ValueError(f"Invalid heuristic '{heuristic}'")

    def split_ports(
        self,
        heuristic: str = "graph_color",
        read_ports: Optional[int] = None,
        write_ports: Optional[int] = None,
        total_ports: Optional[int] = None,
    ) -> Set["ProcessCollection"]:
        """
        Split this process storage based on concurrent read/write times.

        Parameters
        ----------
        heuristic : str, default: "graph_color"
            The heuristic used when splitting this ProcessCollection.
            Valid options are:

            * "graph_color"
        read_ports : int, optional
            The number of read ports used when splitting process collection based on
            memory variable access.
        write_ports : int, optional
            The number of write ports used when splitting process collection based on
            memory variable access.
        total_ports : int, optional
            The total number of ports used when splitting process collection based on
            memory variable access.

        Returns
        -------
        A set of new ProcessCollection objects with the process splitting.

        Raises
        ------
        ValueError
            If the port specification is invalid or the heuristic is unknown.
        """
        # Ports are validated before the heuristic, so port errors take precedence
        read_ports, write_ports, total_ports = _sanitize_port_option(
            read_ports, write_ports, total_ports
        )
        if heuristic != "graph_color":
            raise ValueError("Invalid heuristic provided.")
        return self._split_ports_graph_color(read_ports, write_ports, total_ports)

    def _split_ports_graph_color(
        self,
        read_ports: int,
        write_ports: int,
        total_ports: int,
        coloring_strategy: str = "saturation_largest_first",
    ) -> Set["ProcessCollection"]:
        """
        Split this collection based on port constraints using greedy graph coloring.

        An exclusion graph is created from the port settings, colored with
        :func:`networkx.coloring.greedy_color`, and the resulting coloring is turned
        into new :class:`ProcessCollection` objects.

        Parameters
        ----------
        read_ports : int
            The number of read ports used when splitting process collection based on
            memory variable access.
        write_ports : int
            The number of write ports used when splitting process collection based on
            memory variable access.
        total_ports : int
            The total number of ports used when splitting process collection based on
            memory variable access.
        coloring_strategy : str, default: 'saturation_largest_first'
            Node ordering strategy passed to :func:`networkx.coloring.greedy_color`.
            Valid options are:

            * 'largest_first'
            * 'random_sequential'
            * 'smallest_last'
            * 'independent_set'
            * 'connected_sequential_bfs'
            * 'connected_sequential_dfs' or 'connected_sequential'
            * 'saturation_largest_first' or 'DSATUR'

        Returns
        -------
        A set of new ProcessCollection objects created from the graph coloring.
        """
        # Create new exclusion graph. Nodes are Processes
        exclusion_graph = self.create_exclusion_graph_from_ports(
            read_ports, write_ports, total_ports
        )

        # Perform assignment from coloring and return result
        coloring = nx.coloring.greedy_color(exclusion_graph, strategy=coloring_strategy)
        return self._split_from_graph_coloring(coloring)

    def _split_from_graph_coloring(
        self,
        coloring: Dict[Process, int],
    ) -> Set["ProcessCollection"]:
        """
        Split :class:`Process` objects into a set of :class:`ProcessCollection`
        objects based on a provided graph coloring.

        Resulting :class:`ProcessCollection` will have the same schedule time and cyclic
        property as self.

        Parameters
        ----------
        coloring : dict
            Mapping from each :class:`Process` to its color (int).

        Returns
        -------
        A set of new ProcessCollection objects, one for each color used in
        *coloring*. An empty coloring results in an empty set.
        """
        # Group processes by color. Only colors that actually occur get a group,
        # so an empty coloring does not make max() fail, and gaps in the color
        # numbering do not produce empty collections.
        processes_by_color: Dict[int, Set[Process]] = {}
        for process, color in coloring.items():
            processes_by_color.setdefault(color, set()).add(process)

        return {
            ProcessCollection(processes, self._schedule_time, self._cyclic)
            for processes in processes_by_color.values()
        }

    def _repr_svg_(self) -> str:
        """
        Generate an SVG of the resource collection. This is automatically displayed in
        e.g. Jupyter Qt console.

        Returns
        -------
        str
            The SVG document produced by :meth:`plot`.
        """
        fig, ax = plt.subplots()
        try:
            self.plot(ax=ax, show_markers=False)
            buffer = io.StringIO()
            fig.savefig(buffer, format="svg")  # type: ignore
            return buffer.getvalue()
        finally:
            # Figures created through pyplot stay registered until explicitly
            # closed; close it so repeated rendering does not accumulate figures.
            plt.close(fig)

    def __repr__(self):
        return (
            f"ProcessCollection({self._collection}, {self._schedule_time},"
            f" {self._cyclic})"
        )

    def __iter__(self):
        return iter(self._collection)

    def graph_color_cell_assignment(
        self,
        coloring_strategy: str = "saturation_largest_first",
        *,
        coloring: Optional[Dict[Process, int]] = None,
    ) -> Set["ProcessCollection"]:
        """
        Perform cell assignment of the processes in this collection using graph
        coloring with networkx.coloring.greedy_color.

        Two or more processes can share a single cell if, and only if, they have no
        overlapping time alive.

        Parameters
        ----------
        coloring_strategy : str, default: "saturation_largest_first"
            Graph coloring strategy passed to networkx.coloring.greedy_color().
            Ignored when *coloring* is provided.
        coloring : dictionary, optional
            An optional graph coloring, dictionary with Process and its associated
            color (int). If a graph coloring is not provided through this parameter,
            one will be created when calling this method.

        Returns
        -------
        A set of ProcessCollection, one per cell. The new collections have the same
        schedule time and cyclic property as self.
        """
        if coloring is None:
            # The exclusion graph is only needed to compute a coloring; skip
            # building it when the caller already supplied one.
            exclusion_graph = self.create_exclusion_graph_from_execution_time()
            coloring = nx.coloring.greedy_color(
                exclusion_graph, strategy=coloring_strategy
            )

        cell_assignment: Dict[int, ProcessCollection] = {}
        for process, cell in coloring.items():
            if cell not in cell_assignment:
                cell_assignment[cell] = ProcessCollection(
                    set(), self._schedule_time, self._cyclic
                )
            cell_assignment[cell].add_process(process)
        return set(cell_assignment.values())

    def left_edge_cell_assignment(self) -> Dict[int, "ProcessCollection"]:
        """
        Perform cell assignment of the processes in this collection using the
        left-edge algorithm.

        Two or more processes can share a single cell if, and only if, they have no
        overlapping time alive.

        Returns
        -------
        Dict[int, ProcessCollection]
            Mapping from cell index (consecutive integers starting at 0) to the
            processes assigned to that cell. The new collections have the same
            schedule time and cyclic property as self.
        """
        cell_assignment: Dict[int, ProcessCollection] = {}
        # NOTE(review): assumes sorted(self) orders processes by start time, as the
        # left-edge algorithm requires.
        for next_process in sorted(self):
            # The candidate's stop time wrapped into the schedule. It only depends on
            # next_process, so compute it once instead of once per comparison.
            next_process_stop_time = (
                next_process.start_time + next_process.execution_time
            ) % self._schedule_time
            # A wrapped stop time smaller than the start time means the candidate
            # is alive across the schedule boundary.
            wraps_around = next_process_stop_time < next_process.start_time

            for cell in cell_assignment.values():
                # Conflict if the candidate starts before an existing process ends,
                # or if it wraps around and its wrapped part reaches into one.
                conflicts = any(
                    next_process.start_time
                    < process.start_time + process.execution_time
                    or (wraps_around and next_process_stop_time > process.start_time)
                    for process in cell
                )
                if not conflicts:
                    cell.add_process(next_process)
                    break
            else:
                # No existing cell can take the candidate: open a new one.
                new_cell = ProcessCollection(set(), self._schedule_time, self._cyclic)
                new_cell.add_process(next_process)
                cell_assignment[len(cell_assignment)] = new_cell
        return cell_assignment

    def generate_memory_based_storage_vhdl(
        self,
        filename: str,
        entity_name: str,
        word_length: int,
        assignment: Set['ProcessCollection'],
        read_ports: int = 1,
        write_ports: int = 1,
        total_ports: int = 2,
        *,
        input_sync: bool = True,
    ):
        """
        Generate VHDL code for memory based storage of processes (MemoryVariables).

        Parameters
        ----------
        filename : str
            Filename of output file.
        entity_name : str
            Name used for the VHDL entity.
        word_length : int
            Word length of the memory variable objects.
        assignment : set of ProcessCollection
            The cell assignment to use when generating the memory based storage.
            Each ProcessCollection in the set holds the MemoryVariables assigned to
            one cell. Every MemoryVariable in it must be part of this collection.
        read_ports : int, default: 1
            The number of read ports used when splitting process collection based on
            memory variable access. If total ports is unset, this parameter has to be
            set and total_ports is assumed to be read_ports + write_ports.
        write_ports : int, default: 1
            The number of write ports used when splitting process collection based on
            memory variable access. If total ports is unset, this parameter has to be
            set and total_ports is assumed to be read_ports + write_ports.
        total_ports : int, default: 2
            The total number of ports used when splitting process collection based on
            memory variable access.
        input_sync : bool, default: True
            Add registers to the input signals (enable signal and data input signals).
            Adding registers to the inputs allow pipelining of address generation
            (which is added automatically). For large interleavers, this can improve
            timing significantly.

        Raises
        ------
        ValueError
            If this collection does not consist of (Plain)MemoryVariables, if the
            port specification is invalid, if *assignment* contains a process that
            is not part of this collection, or if more concurrent writes/reads than
            available write/read ports are needed.
        """
        from collections import Counter

        # Check that this is a ProcessCollection of (Plain)MemoryVariables
        is_memory_variable = all(
            isinstance(process, MemoryVariable) for process in self._collection
        )
        is_plain_memory_variable = all(
            isinstance(process, PlainMemoryVariable) for process in self._collection
        )
        if not (is_memory_variable or is_plain_memory_variable):
            raise ValueError(
                "HDL can only be generated for ProcessCollection of"
                " (Plain)MemoryVariables"
            )

        # Sanitize port settings
        read_ports, write_ports, total_ports = _sanitize_port_option(
            read_ports, write_ports, total_ports
        )

        # Make sure the provided assignment (Set[ProcessCollection]) only
        # contains memory variables from this (self).
        for collection in assignment:
            for mv in collection:
                if mv not in self:
                    raise ValueError(f'{mv!r} is not part of {self!r}.')

        # Make sure that concurrent reads/writes do not surpass the port setting.
        # Writes are counted per start time and reads per end time (wrapped into
        # the schedule). Counting once is O(n) instead of filtering per variable.
        write_counts = Counter(mv.start_time for mv in self)
        read_counts = Counter(
            (mv.start_time + mv.execution_time) % self._schedule_time for mv in self
        )
        needed_write_ports = max(write_counts.values(), default=0)
        needed_read_ports = max(read_counts.values(), default=0)
        if needed_write_ports > write_ports:
            raise ValueError(
                f'More than {write_ports} write ports needed ({needed_write_ports})'
                ' to generate HDL for this ProcessCollection'
            )
        if needed_read_ports > read_ports:
            raise ValueError(
                f'More than {read_ports} read ports needed ({needed_read_ports}) to'
                ' generate HDL for this ProcessCollection'
            )

        with open(filename, 'w') as f:
            from b_asic.codegen.vhdl import architecture, common, entity

            common.write_b_asic_vhdl_preamble(f)
            common.write_ieee_header(f)
            entity.write_memory_based_storage(
                f, entity_name=entity_name, collection=self, word_length=word_length
            )
            architecture.write_memory_based_storage(
                f,
                assignment=assignment,
                entity_name=entity_name,
                word_length=word_length,
                read_ports=read_ports,
                write_ports=write_ports,
                total_ports=total_ports,
                input_sync=input_sync,
            )

    def generate_register_based_storage_vhdl(
        self,
        filename: str,
        word_length: int,
        entity_name: str,
        read_ports: int = 1,
        write_ports: int = 1,
        total_ports: int = 2,
    ):
        """
        Generate VHDL code for register based storages of processes based on
        Forward-Backward Register Allocation [1].

        [1]: K. Parhi: VLSI Digital Signal Processing Systems: Design and
        Implementation, Ch. 6.3.2

        Parameters
        ----------
        filename : str
            Filename of output file.
        word_length : int
            Word length of the memory variable objects.
        entity_name : str
            Name used for the VHDL entity.
        read_ports : int, default: 1
            The number of read ports used when splitting process collection based on
            memory variable access. If total ports is unset, this parameter has to be
            set and total_ports is assumed to be read_ports + write_ports.
        write_ports : int, default: 1
            The number of write ports used when splitting process collection based on
            memory variable access. If total ports is unset, this parameter has to be
            set and total_ports is assumed to be read_ports + write_ports.
        total_ports : int, default: 2
            The total number of ports used when splitting process collection based on
            memory variable access.

        Raises
        ------
        ValueError
            If this collection does not consist of (Plain)MemoryVariables, or if the
            port specification is invalid.
        """
        # Check that this is a ProcessCollection of (Plain)MemoryVariables
        is_memory_variable = all(
            isinstance(process, MemoryVariable) for process in self._collection
        )
        is_plain_memory_variable = all(
            isinstance(process, PlainMemoryVariable) for process in self._collection
        )
        if not (is_memory_variable or is_plain_memory_variable):
            raise ValueError(
                "HDL can only be generated for ProcessCollection of"
                " (Plain)MemoryVariables"
            )

        # Sanitize port settings
        read_ports, write_ports, total_ports = _sanitize_port_option(
            read_ports, write_ports, total_ports
        )

        # Create the forward-backward table
        forward_backward_table = _ForwardBackwardTable(self)

        with open(filename, 'w') as f:
            # Import the submodules explicitly: importing only the package does not
            # guarantee that its submodules are loaded as attributes.
            from b_asic.codegen.vhdl import architecture, common, entity

            common.write_b_asic_vhdl_preamble(f)
            common.write_ieee_header(f)
            entity.write_register_based_storage(
                f, entity_name=entity_name, collection=self, word_length=word_length
            )
            architecture.write_register_based_storage(
                f,
                forward_backward_table=forward_backward_table,
                entity_name=entity_name,
                word_length=word_length,
                read_ports=read_ports,
                write_ports=write_ports,
                total_ports=total_ports,
            )