From 669e44a34aed62e7c673ced46f5a24fdbab2409b Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Sat, 15 Mar 2025 18:17:13 +0100
Subject: [PATCH] More unified plots plus minor cleanups

---
 b_asic/architecture.py                 |  5 +++
 b_asic/resources.py                    |  9 ++--
 b_asic/schedule.py                     | 62 ++++++++++++++------------
 b_asic/scheduler_gui/scheduler_item.py |  4 +-
 4 files changed, 48 insertions(+), 32 deletions(-)

diff --git a/b_asic/architecture.py b/b_asic/architecture.py
index d7b41982..bcefb93c 100644
--- a/b_asic/architecture.py
+++ b/b_asic/architecture.py
@@ -2,6 +2,7 @@
 B-ASIC architecture classes.
 """
 
+import math
 from collections import defaultdict
 from collections.abc import Iterable, Iterator
 from io import TextIOWrapper
@@ -262,8 +263,11 @@ class Resource(HardwareBlock):
         """
         fig, ax = plt.subplots(layout="constrained")
         self.plot_content(ax, **kwargs)
+        height = 0.4
         if title:
+            height += 0.4
             fig.suptitle(title)
+        fig.set_figheight(math.floor(max(ax.get_ylim())) * 0.3 + height)
         fig.show()  # type: ignore
 
     @property
@@ -296,6 +300,7 @@ class Resource(HardwareBlock):
         """
         fig, ax = plt.subplots(layout="constrained")
         self.plot_content(ax)
+        fig.set_figheight(math.floor(max(ax.get_ylim())) * 0.3 + 0.4)
         return fig
 
     @property
diff --git a/b_asic/resources.py b/b_asic/resources.py
index 6f979cee..1af2b9b0 100644
--- a/b_asic/resources.py
+++ b/b_asic/resources.py
@@ -3,7 +3,7 @@ import re
 from collections import Counter, defaultdict
 from collections.abc import Iterable
 from functools import reduce
-from math import log2
+from math import floor, log2
 from typing import Literal, TypeVar
 
 import matplotlib.pyplot as plt
@@ -111,7 +111,7 @@ def draw_exclusion_graph_coloring(
         import networkx as nx
         import matplotlib.pyplot as plt
 
-        _, ax = plt.subplots()
+        fig, ax = plt.subplots()
         collection = ProcessCollection(...)
         exclusion_graph = collection.create_exclusion_graph_from_ports(
             read_ports=1,
@@ -120,7 +120,7 @@ def draw_exclusion_graph_coloring(
         )
         coloring = nx.greedy_color(exclusion_graph)
         draw_exclusion_graph_coloring(exclusion_graph, coloring, ax=ax)
-        plt.show()
+        fig.show()
 
     Parameters
     ----------
@@ -735,8 +735,11 @@ class ProcessCollection:
             show_markers=show_markers,
             allow_excessive_lifetimes=allow_excessive_lifetimes,
         )
+        height = 0.4
         if title:
+            height = 0.8
             fig.suptitle(title)
+        fig.set_figheight(floor(max(ax.get_ylim())) * 0.3 + height)
         fig.show()  # type: ignore
 
     def create_exclusion_graph_from_ports(
diff --git a/b_asic/schedule.py b/b_asic/schedule.py
index 9f349f0b..e4d6734d 100644
--- a/b_asic/schedule.py
+++ b/b_asic/schedule.py
@@ -737,7 +737,9 @@ class Schedule:
         self, operation_height: float = 1.0, operation_gap: float = OPERATION_GAP
     ):
         max_pos_graph_id = max(self._y_locations, key=self._y_locations.get)
-        return self._get_y_position(max_pos_graph_id, operation_height, operation_gap)
+        return self._get_y_plot_location(
+            max_pos_graph_id, operation_height, operation_gap
+        )
 
     def place_operation(
         self, op: Operation, time: int, op_laps: dict[GraphID, int]
@@ -1053,7 +1055,7 @@ class Schedule:
         """Get a list of all TypeNames used in the Schedule."""
         return self._sfg.get_used_type_names()
 
-    def _get_y_position(
+    def _get_y_plot_location(
         self, graph_id, operation_height=1.0, operation_gap=OPERATION_GAP
     ) -> float:
         y_location = self._y_locations[graph_id]
@@ -1065,12 +1067,17 @@ class Schedule:
             self._y_locations[graph_id] = y_location
         return operation_gap + y_location * (operation_height + operation_gap)
 
-    def sort_y_locations_on_start_times(self):
+    def sort_y_locations_on_start_times(self) -> None:
         """
         Sort the y-locations of the schedule based on start times of the operations.
 
         Inputs, outputs, dontcares, and sinks are located adjacent to the operations that
         they are connected to.
+
+        See Also
+        --------
+        move_y_location
+        set_y_location
         """
         for i, graph_id in enumerate(
             sorted(self._start_times, key=self._start_times.get)
@@ -1172,25 +1179,24 @@ class Schedule:
 
             else:
                 if end[0] == start[0]:
-                    path = Path(
-                        [
-                            start,
-                            [start[0] + SPLINE_OFFSET, start[1]],
-                            [start[0] - SPLINE_OFFSET, end[1]],
-                            end,
-                        ],
-                        [Path.MOVETO] + [Path.CURVE4] * 3,
-                    )
+                    middle_points = [
+                        [start[0] + SPLINE_OFFSET, start[1]],
+                        [start[0] - SPLINE_OFFSET, end[1]],
+                    ]
                 else:
-                    path = Path(
-                        [
-                            start,
-                            [(start[0] + end[0]) / 2, start[1]],
-                            [(start[0] + end[0]) / 2, end[1]],
-                            end,
-                        ],
-                        [Path.MOVETO] + [Path.CURVE4] * 3,
-                    )
+                    middle_points = [
+                        [(start[0] + end[0]) / 2, start[1]],
+                        [(start[0] + end[0]) / 2, end[1]],
+                    ]
+
+                path = Path(
+                    [
+                        start,
+                        *middle_points,
+                        end,
+                    ],
+                    [Path.MOVETO] + [Path.CURVE4] * 3,
+                )
                 path_patch = PathPatch(
                     path,
                     fc='none',
@@ -1221,7 +1227,7 @@ class Schedule:
         ax.set_axisbelow(True)
         ax.grid()
         for graph_id, op_start_time in self._start_times.items():
-            y_pos = self._get_y_position(graph_id, operation_gap=operation_gap)
+            y_pos = self._get_y_plot_location(graph_id, operation_gap=operation_gap)
             operation = cast(Operation, self._sfg.find_by_id(graph_id))
             # Rewrite to make better use of NumPy
             (
@@ -1278,14 +1284,16 @@ class Schedule:
         for graph_id, op_start_time in self._start_times.items():
             operation = cast(Operation, self._sfg.find_by_id(graph_id))
             out_coordinates = operation.get_output_coordinates()
-            source_y_pos = self._get_y_position(graph_id, operation_gap=operation_gap)
+            source_y_pos = self._get_y_plot_location(
+                graph_id, operation_gap=operation_gap
+            )
 
             for output_port in operation.outputs:
                 for output_signal in output_port.signals:
                     destination = cast(InputPort, output_signal.destination)
                     destination_op = destination.operation
                     destination_start_time = self._start_times[destination_op.graph_id]
-                    destination_y_pos = self._get_y_position(
+                    destination_y_pos = self._get_y_plot_location(
                         destination_op.graph_id, operation_gap=operation_gap
                     )
                     destination_in_coordinates = (
@@ -1306,7 +1314,7 @@ class Schedule:
         # Get operation with maximum position
         max_pos_graph_id = max(self._y_locations, key=self._y_locations.get)
         y_position_max = (
-            self._get_y_position(max_pos_graph_id, operation_gap=operation_gap)
+            self._get_y_plot_location(max_pos_graph_id, operation_gap=operation_gap)
             + 1
             + (OPERATION_GAP if operation_gap is None else operation_gap)
         )
@@ -1384,9 +1392,7 @@ class Schedule:
         Generate an SVG of the schedule. This is automatically displayed in e.g.
         Jupyter Qt console.
         """
-        height = len(self._start_times) * 0.3 + 0.7
-        fig, ax = plt.subplots(figsize=(12, height), layout="constrained")
-        self._plot_schedule(ax)
+        fig = self._get_figure()
         buffer = io.StringIO()
         fig.savefig(buffer, format="svg")
 
diff --git a/b_asic/scheduler_gui/scheduler_item.py b/b_asic/scheduler_gui/scheduler_item.py
index be8d8643..1bae4d31 100644
--- a/b_asic/scheduler_gui/scheduler_item.py
+++ b/b_asic/scheduler_gui/scheduler_item.py
@@ -319,7 +319,9 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup):
         op_item = self._operation_items[graph_id]
         op_item.setPos(
             self._x_axis_indent + self.schedule.start_times[graph_id],
-            self.schedule._get_y_position(graph_id, OPERATION_HEIGHT, OPERATION_GAP),
+            self.schedule._get_y_plot_location(
+                graph_id, OPERATION_HEIGHT, OPERATION_GAP
+            ),
         )
 
     def _redraw_from_start(self) -> None:
-- 
GitLab