Newer
Older
slacks = self._forward_slacks(graph_id)
for outport, signals in slacks.items():
Simon Bjurek
committed
if outport.name.startswith("dontcare"):
continue
for signal, slack in signals.items()
}
ret.append(
MemoryVariable(
(start_time + cast(int, outport.latency_offset))
% self.schedule_time,
)
)
return ret
def get_memory_variables(self) -> ProcessCollection:
    """
    Collect every memory variable of the schedule into a ProcessCollection.

    The variables produced by ``_get_memory_variables_list`` are de-duplicated
    into a set and wrapped together with this schedule's schedule time.
    """
    variables = set(self._get_memory_variables_list())
    return ProcessCollection(variables, self.schedule_time)
def get_operations(self) -> ProcessCollection:
    """
    Return a ProcessCollection containing all operations.

    Every scheduled operation becomes an ``OperatorProcess`` starting at its
    scheduled start time. ``DontCare`` and ``Sink`` operations are left out.

    Returns
    -------
    ProcessCollection
        The operator processes, together with this schedule's schedule time
        and cyclic setting.
    """
    processes = set()
    for graph_id, start_time in self._start_times.items():
        # Look the operation up once and reuse it for both the filter and
        # the process.
        operation = cast(Operation, self._sfg.find_by_id(graph_id))
        if isinstance(operation, (DontCare, Sink)):
            continue
        processes.add(OperatorProcess(start_time, operation))
    return ProcessCollection(processes, self.schedule_time, self.cyclic)
def get_used_type_names(self) -> list[TypeName]:
    """
    Return the TypeNames that occur in the Schedule.

    The answer is taken directly from ``self._sfg.get_used_type_names()``.
    """
    sfg = self._sfg
    return sfg.get_used_type_names()
def _get_y_position(
    self, graph_id, operation_height=1.0, operation_gap=OPERATION_GAP
):
    """
    Return the vertical plot coordinate of the operation *graph_id*.

    If the operation has no row yet (its stored y-location is ``None``), it
    is assigned the lowest non-negative row number that no other operation
    currently uses. That row is stored in ``self._y_locations``.

    Parameters
    ----------
    graph_id
        Graph id of the operation.
    operation_height : float, default 1.0
        Height of one operation in plot units.
    operation_gap : float, optional
        Vertical gap between rows, also used as the top margin. ``None``
        falls back to ``OPERATION_GAP``.

    Returns
    -------
    float
        ``operation_gap + row * (operation_height + operation_gap)``.
    """
    if operation_gap is None:
        # Same fallback the plotting code uses for a missing gap.
        operation_gap = OPERATION_GAP
    y_location = self._y_locations[graph_id]
    if y_location is None:
        # Assign the lowest row number not yet in use. Search upwards from 0
        # without an upper bound. Rows set via set_y_location/move_y_location
        # may exceed the number of operations, so a bounded search could
        # find no free row.
        used = {loc for loc in self._y_locations.values() if loc is not None}
        y_location = 0
        while y_location in used:
            y_location += 1
        self._y_locations[graph_id] = y_location
    return operation_gap + y_location * (operation_height + operation_gap)
Simon Bjurek
committed
def sort_y_locations_on_start_times(self):
    """
    Reorder the plot rows of all operations by their start times.

    Each operation first gets the row equal to its rank when sorted by start
    time (ties keep their original order). Then, via ``move_y_location``:

    - each Output is relocated to one row past its first preceding operation;
    - each DontCare is relocated to the row of its first subsequent operation.
    """
    by_start = sorted(self._start_times.items(), key=lambda item: item[1])
    for row, (graph_id, _) in enumerate(by_start):
        self.set_y_location(graph_id, row)

    for graph_id in self._start_times:
        operation = cast(Operation, self._sfg.find_by_id(graph_id))
        # NOTE(review): the meaning of the trailing ``True`` passed to
        # move_y_location is not visible here -- confirm against its definition.
        if isinstance(operation, Output):
            anchor_id = operation.preceding_operations[0].graph_id
            self.move_y_location(graph_id, self.get_y_location(anchor_id) + 1, True)
        if isinstance(operation, DontCare):
            anchor_id = operation.subsequent_operations[0].graph_id
            self.move_y_location(graph_id, self.get_y_location(anchor_id), True)
Simon Bjurek
committed
def _plot_schedule(self, ax: Axes, operation_gap: float = OPERATION_GAP) -> None:
# Draw the scheduled operations, the signals between them, the y-axis labels
# and the schedule boundaries into *ax*.
# NOTE(review): this block was pasted with lines missing; each gap is flagged
# below. Confirm against the repository before relying on this text.
# NOTE(review): the header of a nested helper (presumably
# ``def _draw_arrow(``) is not visible; the next two lines are its signature tail.
start: Sequence[float], end: Sequence[float], name: str = "", laps: int = 0
) -> None:
"""Draw an arrow from *start* to *end*."""
# Wrap x-coordinates strictly greater than the schedule time back into the
# schedule. This mutates the *start*/*end* sequences in place.
if end[0] > self.schedule_time:
end[0] %= self.schedule_time
if start[0] > self.schedule_time:
start[0] %= self.schedule_time
# NOTE(review): ``line_cache`` is not defined in the visible lines; it is
# presumably a list in the enclosing scope -- confirm. The line and label
# below are drawn at most once per *start* point, which is recorded in it.
if start not in line_cache:
# Horizontal line from *start* to just past the schedule end, labelled with *name*.
line = Line2D(
[start[0], self._schedule_time + SCHEDULE_OFFSET],
# NOTE(review): the y-coordinate list and closing parenthesis of this
# Line2D call appear to be missing.
color=_SIGNAL_COLOR,
lw=SIGNAL_LINEWIDTH,
ax.add_line(line)
ax.text(
self._schedule_time + SCHEDULE_OFFSET,
start[1],
name,
verticalalignment="center",
)
# Horizontal line from just before time 0 to *end*.
line = Line2D(
[-SCHEDULE_OFFSET, end[0]],
[end[1], end[1]],
color=_SIGNAL_COLOR,
lw=SIGNAL_LINEWIDTH,
# NOTE(review): the two alignment arguments below look like they belong
# to a separate ax.text(...) call whose opening lines (and the closing of
# this Line2D / its ax.add_line) are missing.
verticalalignment="center",
horizontalalignment="right",
)
line_cache.append(start)
# Connect the points with a cubic Bezier curve (MOVETO + three CURVE4
# points). When both ends share an x-coordinate, the control points are
# pushed sideways by SPLINE_OFFSET so the curve remains visible.
# NOTE(review): the original nesting of this part relative to the
# ``line_cache`` branch cannot be recovered from the pasted text -- confirm.
if end[0] == start[0]:
path = Path(
[
start,
[start[0] + SPLINE_OFFSET, start[1]],
[start[0] - SPLINE_OFFSET, end[1]],
end,
],
[Path.MOVETO] + [Path.CURVE4] * 3,
)
else:
path = Path(
[
start,
[(start[0] + end[0]) / 2, start[1]],
[(start[0] + end[0]) / 2, end[1]],
end,
],
[Path.MOVETO] + [Path.CURVE4] * 3,
)
path_patch = PathPatch(
path,
fc='none',
ec=_SIGNAL_COLOR,
lw=SIGNAL_LINEWIDTH,
zorder=10,
)
ax.add_patch(path_patch)
# NOTE(review): ``laps`` is accepted by the arrow helpers but not used in the
# visible lines.
def _draw_offset_arrow(
start: Sequence[float],
end: Sequence[float],
start_offset: Sequence[float],
end_offset: Sequence[float],
name: str = "",
laps: int = 0,
) -> None:
"""Draw an arrow from *start* to *end*, but with an offset."""
# Offsets are added component-wise to the two points before drawing.
_draw_arrow(
[start[0] + start_offset[0], start[1] + start_offset[1]],
[end[0] + end_offset[0], end[1] + end_offset[1]],
name=name,
laps=laps,
)
ytickpositions = []
yticklabels = []
ax.set_axisbelow(True)
ax.grid()
# Pass 1: draw each operation's latency polygon, execution-time line and label.
for graph_id, op_start_time in self._start_times.items():
y_pos = self._get_y_position(graph_id, operation_gap=operation_gap)
operation = cast(Operation, self._sfg.find_by_id(graph_id))
# Rewrite to make better use of NumPy
(
latency_coordinates,
execution_time_coordinates,
) = operation.get_plot_coordinates()
_x, _y = zip(*latency_coordinates)
x = np.array(_x)
y = np.array(_y)
xvalues = x + op_start_time
xy = np.stack((xvalues, y + y_pos))
p = ax.add_patch(Polygon(xy.T, fc=_LATENCY_COLOR))
# Clip the patch to the full axes area (axes coordinates 0..1).
p.set_clip_box(TransformedBbox(Bbox([[0, 0], [1, 1]]), ax.transAxes))
# Parts extending past the schedule time are drawn again, shifted back by
# one schedule period, so they wrap to the start (not done for Outputs).
if any(xvalues > self.schedule_time) and not isinstance(operation, Output):
xy = np.stack((xvalues - self.schedule_time, y + y_pos))
p = ax.add_patch(Polygon(xy.T, fc=_LATENCY_COLOR))
p.set_clip_box(TransformedBbox(Bbox([[0, 0], [1, 1]]), ax.transAxes))
# Inputs are labelled to the left of their start time, other operations
# just to the right of it.
# NOTE(review): the font size 10 - 0.05 * N is zero or negative once there
# are 200 or more operations -- consider clamping it to a minimum.
if isinstance(operation, Input):
ax.annotate(
graph_id,
xy=(op_start_time - 0.48, y_pos + 0.7),
color="black",
size=10 - (0.05 * len(self._start_times)),
)
else:
ax.annotate(
graph_id,
xy=(op_start_time + 0.03, y_pos + 0.7),
color="black",
size=10 - (0.05 * len(self._start_times)),
)
if execution_time_coordinates:
_x, _y = zip(*execution_time_coordinates)
x = np.array(_x)
y = np.array(_y)
xvalues = x + op_start_time
# NOTE(review): the ``ax.plot(`` call these arguments belong to is only
# partially visible (its opening, y-values, linewidth and closing are missing).
xvalues,
color=_EXECUTION_TIME_COLOR,
if any(xvalues > self.schedule_time) and not isinstance(
operation, Output
):
ax.plot(
xvalues - self.schedule_time,
y + y_pos,
color=_EXECUTION_TIME_COLOR,
linewidth=3,
)
ytickpositions.append(y_pos + 0.5)
yticklabels.append(cast(Operation, self._sfg.find_by_id(graph_id)).name)
# Pass 2: draw one arrow per signal from its source operation to its
# destination operation.
for graph_id, op_start_time in self._start_times.items():
operation = cast(Operation, self._sfg.find_by_id(graph_id))
out_coordinates = operation.get_output_coordinates()
source_y_pos = self._get_y_position(graph_id, operation_gap=operation_gap)
for output_port in operation.outputs:
for output_signal in output_port.signals:
destination = cast(InputPort, output_signal.destination)
destination_op = destination.operation
destination_start_time = self._start_times[destination_op.graph_id]
destination_y_pos = self._get_y_position(
destination_op.graph_id, operation_gap=operation_gap
# NOTE(review): several lines are missing here (closing of the call above,
# the destination input coordinates, and the arrow-drawing call that takes
# the arguments below).
destination_in_coordinates = (
out_coordinates[output_port.index],
[op_start_time, source_y_pos],
[destination_start_time, destination_y_pos],
name=graph_id,
ax.set_yticks(ytickpositions)
ax.set_yticklabels(yticklabels)
# Get operation with maximum position
max_pos_graph_id = max(self._y_locations, key=self._y_locations.get)
# NOTE(review): the assignment opening this expression (presumably
# ``y_position_max = (``) is missing.
self._get_y_position(max_pos_graph_id, operation_gap=operation_gap)
+ 1
+ (OPERATION_GAP if operation_gap is None else operation_gap)
)
ax.axis([-1, self._schedule_time + 1, y_position_max, 0]) # Inverted y-axis
ax.xaxis.set_major_locator(MaxNLocator(integer=True, min_n_ticks=1))
# Dashed vertical lines mark time 0 and the schedule time.
# NOTE(review): the closing parentheses of both axvline calls are missing.
ax.axvline(
0,
linestyle="--",
color="black",
ax.axvline(
self._schedule_time,
linestyle="--",
color="black",
def reset_y_locations(self) -> None:
    """
    Reset all the y-locations in the schedule to None.

    The stored mapping is replaced by a fresh ``defaultdict`` whose
    missing entries are produced by ``_y_locations_default``.
    """
    fresh_locations = defaultdict(_y_locations_default)
    self._y_locations = fresh_locations
def plot(self, ax: Axes, operation_gap: float = OPERATION_GAP) -> None:
"""
Plot the schedule in a :class:`matplotlib.axes.Axes` or subclass.
Parameters
----------
ax : :class:`~matplotlib.axes.Axes`
The :class:`matplotlib.axes.Axes` to plot in.
operation_gap : float, optional
The vertical distance between operations in the schedule. The height of
the operation is always 1.
self, operation_gap: float = OPERATION_GAP, title: str | None = None
Show the schedule. Will display based on the current Matplotlib backend.
Parameters
----------
operation_gap : float, optional
The vertical distance between operations in the schedule. The height of
the operation is always 1.
fig = self._get_figure(operation_gap=operation_gap)
if title:
fig.suptitle(title)
fig.show()
def _get_figure(self, operation_gap: float = OPERATION_GAP) -> Figure:
"""
Create a Figure and an Axes and plot schedule in the Axes.
Parameters
----------
operation_gap : float, optional
The vertical distance between operations in the schedule. The height of
the operation is always 1.
height = len(self._start_times) * 0.3 + 2
fig, ax = plt.subplots(figsize=(12, height))
self._plot_schedule(ax, operation_gap=operation_gap)
Generate an SVG of the schedule. This is automatically displayed in e.g.
Jupyter Qt console.
height = len(self._start_times) * 0.3 + 2
fig, ax = plt.subplots(figsize=(12, height))
self._plot_schedule(ax)
buffer = io.StringIO()
fig.savefig(buffer, format="svg")
return buffer.getvalue()
# SVG is valid HTML, so the SVG representation doubles as the HTML
# representation. This is useful for e.g. sphinx-gallery.
_repr_html_ = _repr_svg_