Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • da/B-ASIC
  • lukja239/B-ASIC
  • robal695/B-ASIC
3 results
Show changes
Commits on Source (33)
Showing
with 988 additions and 623 deletions
......@@ -9,6 +9,7 @@ import logging
import os
import sys
from pprint import pprint
from typing import Optional, Tuple
from qtpy.QtCore import QFileInfo, QSize, Qt
from qtpy.QtGui import QCursor, QIcon, QKeySequence, QPainter
......@@ -33,13 +34,17 @@ from b_asic.GUI._preferences import GAP, GRID, MINBUTTONSIZE, PORTHEIGHT
from b_asic.GUI.arrow import Arrow
from b_asic.GUI.drag_button import DragButton
from b_asic.GUI.gui_interface import Ui_main_window
from b_asic.GUI.plot_window import PlotWindow
from b_asic.GUI.port_button import PortButton
from b_asic.GUI.select_sfg_window import SelectSFGWindow
from b_asic.GUI.show_pc_window import ShowPCWindow
from b_asic.GUI.simulate_sfg_window import Plot, SimulateSFGWindow
# from b_asic.GUI.simulate_sfg_window import Plot, SimulateSFGWindow
from b_asic.GUI.simulate_sfg_window import SimulateSFGWindow
from b_asic.GUI.util_dialogs import FaqWindow, KeybindsWindow
from b_asic.GUI.utils import decorate_class, handle_error
from b_asic.gui_utils.about_window import AboutWindow
from b_asic.operation import Operation
from b_asic.port import InputPort, OutputPort
from b_asic.save_load_structure import python_to_sfg, sfg_to_python
from b_asic.signal_flow_graph import SFG
......@@ -85,9 +90,7 @@ class MainWindow(QMainWindow):
b_asic.special_operations, self.ui.special_operations_list
)
self.shortcut_core = QShortcut(
QKeySequence("Ctrl+R"), self.ui.operation_box
)
self.shortcut_core = QShortcut(QKeySequence("Ctrl+R"), self.ui.operation_box)
self.shortcut_core.activated.connect(
self._refresh_operations_list_from_namespace
)
......@@ -137,11 +140,11 @@ class MainWindow(QMainWindow):
self.cursor = QCursor()
def init_ui(self):
def init_ui(self) -> None:
self.create_toolbar_view()
self.create_graphics_view()
def create_graphics_view(self):
def create_graphics_view(self) -> None:
self.graphic_view = QGraphicsView(self.scene, self)
self.graphic_view.setRenderHint(QPainter.Antialiasing)
self.graphic_view.setGeometry(
......@@ -149,12 +152,12 @@ class MainWindow(QMainWindow):
)
self.graphic_view.setDragMode(QGraphicsView.RubberBandDrag)
def create_toolbar_view(self):
def create_toolbar_view(self) -> None:
self.toolbar = self.addToolBar("Toolbar")
self.toolbar.addAction("Create SFG", self.create_sfg_from_toolbar)
self.toolbar.addAction("Clear workspace", self.clear_workspace)
def resizeEvent(self, event):
def resizeEvent(self, event) -> None:
self.ui.operation_box.setGeometry(
10, 10, self.ui.operation_box.width(), self.height()
)
......@@ -166,14 +169,14 @@ class MainWindow(QMainWindow):
)
super().resizeEvent(event)
def wheelEvent(self, event):
def wheelEvent(self, event) -> None:
if event.modifiers() == Qt.KeyboardModifier.ControlModifier:
old_zoom = self.zoom
self.zoom += event.angleDelta().y() / 2500
self.graphic_view.scale(self.zoom, self.zoom)
self.zoom = old_zoom
def view_operation_names(self):
def view_operation_names(self) -> None:
if self.check_show_names.isChecked():
self.is_show_names = True
else:
......@@ -183,7 +186,7 @@ class MainWindow(QMainWindow):
operation.label.setOpacity(self.is_show_names)
operation.is_show_name = self.is_show_names
def _save_work(self):
def _save_work(self) -> None:
sfg = self.sfg_widget.sfg
file_dialog = QFileDialog()
file_dialog.setDefaultSuffix(".py")
......@@ -203,14 +206,10 @@ class MainWindow(QMainWindow):
try:
with open(module, "w+") as file_obj:
file_obj.write(
sfg_to_python(
sfg, suffix=f"positions = {operation_positions}"
)
sfg_to_python(sfg, suffix=f"positions = {operation_positions}")
)
except Exception as e:
self.logger.error(
f"Failed to save SFG to path: {module}, with error: {e}."
)
self.logger.error(f"Failed to save SFG to path: {module}, with error: {e}.")
return
self.logger.info("Saved SFG to path: " + str(module))
......@@ -234,8 +233,7 @@ class MainWindow(QMainWindow):
sfg, positions = python_to_sfg(module)
except ImportError as e:
self.logger.error(
f"Failed to load module: {module} with the following error:"
f" {e}."
f"Failed to load module: {module} with the following error: {e}."
)
return
......@@ -263,12 +261,8 @@ class MainWindow(QMainWindow):
# print(op)
self.create_operation(
op,
positions[op.graph_id][0:2]
if op.graph_id in positions
else None,
positions[op.graph_id][-1]
if op.graph_id in positions
else None,
positions[op.graph_id][0:2] if op.graph_id in positions else None,
positions[op.graph_id][-1] if op.graph_id in positions else None,
)
def connect_ports(ports):
......@@ -284,9 +278,7 @@ class MainWindow(QMainWindow):
destination = [
destination
for destination in self.portDict[
self.operationDragDict[
signal.destination.operation
]
self.operationDragDict[signal.destination.operation]
]
if destination.port is signal.destination
]
......@@ -325,7 +317,7 @@ class MainWindow(QMainWindow):
self.scene.clear()
self.logger.info("Workspace cleared.")
def create_sfg_from_toolbar(self):
def create_sfg_from_toolbar(self) -> None:
inputs = []
outputs = []
for op in self.pressed_operations:
......@@ -344,14 +336,10 @@ class MainWindow(QMainWindow):
self.logger.warning("Failed to initialize SFG with empty name.")
return
self.logger.info(
"Creating SFG with name: %s from selected operations." % name
)
self.logger.info("Creating SFG with name: %s from selected operations." % name)
sfg = SFG(inputs=inputs, outputs=outputs, name=name)
self.logger.info(
"Created SFG with name: %s from selected operations." % name
)
self.logger.info("Created SFG with name: %s from selected operations." % name)
def check_equality(signal, signal_2):
if not (
......@@ -369,8 +357,7 @@ class MainWindow(QMainWindow):
and hasattr(signal_2.destination.operation, "value")
):
if not (
signal.source.operation.value
== signal_2.source.operation.value
signal.source.operation.value == signal_2.source.operation.value
and signal.destination.operation.value
== signal_2.destination.operation.value
):
......@@ -383,8 +370,7 @@ class MainWindow(QMainWindow):
and hasattr(signal_2.destination.operation, "name")
):
if not (
signal.source.operation.name
== signal_2.source.operation.name
signal.source.operation.name == signal_2.source.operation.name
and signal.destination.operation.name
== signal_2.destination.operation.name
):
......@@ -453,12 +439,12 @@ class MainWindow(QMainWindow):
self.sfg_dict[sfg.name] = sfg
def _show_precedence_graph(self, event=None):
def _show_precedence_graph(self, event=None) -> None:
self.dialog = ShowPCWindow(self)
self.dialog.add_sfg_to_dialog()
self.dialog.show()
def get_operations_from_namespace(self, namespace):
def get_operations_from_namespace(self, namespace) -> None:
self.logger.info(
"Fetching operations from namespace: " + str(namespace.__name__)
)
......@@ -468,7 +454,7 @@ class MainWindow(QMainWindow):
if hasattr(getattr(namespace, comp), "type_name")
]
def add_operations_from_namespace(self, namespace, _list):
def add_operations_from_namespace(self, namespace, _list) -> None:
for attr_name in self.get_operations_from_namespace(namespace):
attr = getattr(namespace, attr_name)
try:
......@@ -479,11 +465,9 @@ class MainWindow(QMainWindow):
except NotImplementedError:
pass
self.logger.info(
"Added operations from namespace: " + str(namespace.__name__)
)
self.logger.info("Added operations from namespace: " + str(namespace.__name__))
def add_namespace(self, event=None):
def add_namespace(self, event=None) -> None:
module, accepted = QFileDialog().getOpenFileName()
if not accepted:
return
......@@ -494,16 +478,25 @@ class MainWindow(QMainWindow):
namespace = importlib.util.module_from_spec(spec)
spec.loader.exec_module(namespace)
self.add_operations_from_namespace(
namespace, self.ui.custom_operations_list
)
self.add_operations_from_namespace(namespace, self.ui.custom_operations_list)
def create_operation(self, op, position=None, is_flipped: bool = False):
def create_operation(
self,
op: Operation,
position: Optional[Tuple[float, float]] = None,
is_flipped: bool = False,
) -> None:
"""
Parameters
----------
op : Operation
position : (float, float), optional
is_flipped : bool, default: False
"""
try:
if op in self.operationDragDict:
self.logger.warning(
"Multiple instances of operation with same name"
)
self.logger.warning("Multiple instances of operation with same name")
return
attr_button = DragButton(op.graph_id, op, True, window=self)
......@@ -570,7 +563,7 @@ class MainWindow(QMainWindow):
"Unexpected error occurred while creating operation: " + str(e)
)
def _create_operation_item(self, item):
def _create_operation_item(self, item) -> None:
self.logger.info("Creating operation of type: %s" % str(item.text()))
try:
attr_oper = self._operations_from_name[item.text()]()
......@@ -580,7 +573,7 @@ class MainWindow(QMainWindow):
"Unexpected error occurred while creating operation: " + str(e)
)
def _refresh_operations_list_from_namespace(self):
def _refresh_operations_list_from_namespace(self) -> None:
self.logger.info("Refreshing operation list.")
self.ui.core_operations_list.clear()
self.ui.special_operations_list.clear()
......@@ -593,10 +586,10 @@ class MainWindow(QMainWindow):
)
self.logger.info("Finished refreshing operation list.")
def on_list_widget_item_clicked(self, item):
def on_list_widget_item_clicked(self, item) -> None:
self._create_operation_item(item)
def keyPressEvent(self, event):
def keyPressEvent(self, event) -> None:
if event.key() == Qt.Key.Key_Delete:
for pressed_op in self.pressed_operations:
pressed_op.remove()
......@@ -604,11 +597,10 @@ class MainWindow(QMainWindow):
self.pressed_operations.clear()
super().keyPressEvent(event)
def _connect_callback(self, *event):
def _connect_callback(self, *event) -> None:
if len(self.pressed_ports) < 2:
self.logger.warning(
"Cannot connect less than two ports. Please select at least"
" two."
"Cannot connect less than two ports. Please select at least two."
)
return
......@@ -633,9 +625,7 @@ class MainWindow(QMainWindow):
for port in self.pressed_ports:
port.select_port()
def _connect_button(
self, source: PortButton, destination: PortButton
) -> None:
def _connect_button(self, source: PortButton, destination: PortButton) -> None:
"""
Connect two PortButtons with an Arrow.
......@@ -695,28 +685,23 @@ class MainWindow(QMainWindow):
def _simulate_sfg(self):
for sfg, properties in self.dialog.properties.items():
self.logger.info("Simulating SFG with name: %s" % str(sfg.name))
simulation = FastSimulation(
sfg, input_providers=properties["input_values"]
)
simulation = FastSimulation(sfg, input_providers=properties["input_values"])
l_result = simulation.run_for(
properties["iteration_count"],
save_results=properties["all_results"],
)
print(f"{'=' * 10} {sfg.name} {'=' * 10}")
pprint(
simulation.results if properties["all_results"] else l_result
)
pprint(simulation.results if properties["all_results"] else l_result)
print(f"{'=' * 10} /{sfg.name} {'=' * 10}")
if properties["show_plot"]:
self.logger.info(
"Opening plot for SFG with name: " + str(sfg.name)
)
self.logger.info("Opening plot for SFG with name: " + str(sfg.name))
self.logger.info(
"To save the plot press 'Ctrl+S' when the plot is focused."
)
self.plot = Plot(simulation, sfg, self)
# self.plot = Plot(simulation, sfg, self)
self.plot = PlotWindow(simulation.results)
self.plot.show()
def simulate_sfg(self, event=None):
......
# TODO's:
# * Solve the legend update. That isn't working at all.
# * Zoom etc. Might need to change FigureCanvas. Or just something very little.
import re
import sys
from matplotlib.backends.backend_qt5agg import (
FigureCanvasQTAgg as FigureCanvas,
)
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
from qtpy.QtCore import Qt
from qtpy.QtGui import QKeySequence
# Intereme imports for the Plot class:
from qtpy.QtWidgets import ( # QFrame,; QScrollArea,; QLineEdit,; QSizePolicy,; QLabel,
QApplication,
QCheckBox,
QDialog,
QFileDialog,
QHBoxLayout,
QListWidget,
QListWidgetItem,
QPushButton,
QShortcut,
QVBoxLayout,
)
class PlotCanvas(FigureCanvas):
"""PlotCanvas is used as a part in the PlotWindow."""
def __init__(self, logger, parent=None, width=5, height=4, dpi=100):
fig = Figure(figsize=(width, height), dpi=dpi)
super().__init__(fig)
self.axes = fig.add_subplot(111)
self.axes.xaxis.set_major_locator(MaxNLocator(integer=True))
self.legend = None
self.logger = logger
FigureCanvas.updateGeometry(self)
self.save_figure = QShortcut(QKeySequence("Ctrl+S"), self)
self.save_figure.activated.connect(self._save_plot_figure)
def _save_plot_figure(self):
self.logger.info(f"Saving plot of figure: {self.sfg.name}.")
file_choices = "PNG (*.png)|*.png"
path, ext = QFileDialog.getSaveFileName(
self, "Save file", "", file_choices
)
path = path.encode("utf-8")
if not path[-4:] == file_choices[-4:].encode("utf-8"):
path += file_choices[-4:].encode("utf-8")
if path:
self.print_figure(path.decode(), dpi=self.dpi)
self.logger.info(f"Saved plot: {self.sfg.name} to path: {path}.")
class PlotWindow(QDialog):
"""Dialog for plotting the result of a simulation."""
def __init__(
self,
sim_result,
# sfg_name="{sfg_name}",
# window=None,
logger=print,
parent=None,
# width=5,
# height=4,
# dpi=100,
):
super().__init__(parent=parent)
# self._window = window
self.setWindowFlags(
Qt.WindowTitleHint
| Qt.WindowCloseButtonHint
| Qt.WindowMinimizeButtonHint
| Qt.WindowMaximizeButtonHint
| Qt.WindowStaysOnTopHint
)
self.setWindowTitle("Simulation result")
self.sim_result = sim_result
self._auto_redraw = False
# Categorise sim_results into inputs, outputs, delays, others
sim_res_ins = {}
sim_res_outs = {}
sim_res_delays = {}
sim_res_others = {}
for key in sim_result:
if re.fullmatch("in[0-9]+", key):
sim_res_ins[key] = sim_result[key]
elif re.fullmatch("[0-9]+", key):
sim_res_outs[key] = sim_result[key]
elif re.fullmatch("t[0-9]+", key):
sim_res_delays[key] = sim_result[key]
else:
sim_res_others[key] = sim_result[key]
# Layout: ############################################
# | list | |
# | ... | plot |
# | misc | |
self.dialog_layout = QHBoxLayout()
self.setLayout(self.dialog_layout)
listlayout = QVBoxLayout()
self.plotcanvas = PlotCanvas(
logger=logger, parent=self, width=5, height=4, dpi=100
)
self.dialog_layout.addLayout(listlayout)
self.dialog_layout.addWidget(self.plotcanvas)
########### Plot: ##############
# Do this before the list layout, as the list layout will re/set visibility
# Note: The order is of importens. Interesting lines last, to be on top.
self._lines = {}
for key in (
sim_res_others | sim_res_delays | sim_res_ins | sim_res_outs
):
line = self.plotcanvas.axes.plot(
sim_result[key], visible=False, label=key
)
self._lines[key] = line
self.plotcanvas.legend = self.plotcanvas.axes.legend()
########### List layout: ##############
# Add two buttons for selecting all/none:
hlayout = QHBoxLayout()
button_all = QPushButton("&All")
button_all.clicked.connect(self._button_all_click)
hlayout.addWidget(button_all)
button_none = QPushButton("&None")
button_none.clicked.connect(self._button_none_click)
hlayout.addWidget(button_none)
listlayout.addLayout(hlayout)
# Add the entire list
self.checklist = QListWidget()
self.checklist.itemChanged.connect(self._item_change)
listitems = {}
for key in (
sim_res_ins | sim_res_outs | sim_res_delays | sim_res_others
):
listitem = QListWidgetItem(key)
listitems[key] = listitem
self.checklist.addItem(listitem)
listitem.setCheckState(
Qt.CheckState.Unchecked # CheckState: Qt.CheckState.{Unchecked, PartiallyChecked, Checked}
)
for key in sim_res_outs:
listitems[key].setCheckState(Qt.CheckState.Checked)
self.checklist.setFixedWidth(150)
listlayout.addWidget(self.checklist)
# Add additional checkboxes
self.legend_checkbox = QCheckBox("&Legend")
self.legend_checkbox.stateChanged.connect(self._legend_checkbox_change)
self.legend_checkbox.setCheckState(Qt.CheckState.Checked)
listlayout.addWidget(self.legend_checkbox)
# self.ontop_checkbox = QCheckBox("&On top")
# self.ontop_checkbox.stateChanged.connect(self._ontop_checkbox_change)
# self.ontop_checkbox.setCheckState(Qt.CheckState.Unchecked)
# listlayout.addWidget(self.ontop_checkbox)
# Add "Close" buttons
buttonClose = QPushButton("&Close", self)
buttonClose.clicked.connect(self.close)
listlayout.addWidget(buttonClose)
# Done. Tell the functions below to redraw the canvas when needed.
# self.plotcanvas.draw()
self._auto_redraw = True
def _legend_checkbox_change(self, checkState):
self.plotcanvas.legend.set(
visible=(checkState == Qt.CheckState.Checked)
)
if self._auto_redraw:
if checkState == Qt.CheckState.Checked:
self.plotcanvas.legend = self.plotcanvas.axes.legend()
self.plotcanvas.draw()
# def _ontop_checkbox_change(self, checkState):
# Bugg: It seems the window closes if you change the WindowStaysOnTopHint.
# (Nothing happens if "changing" from False to False or True to True)
# self.setWindowFlag(Qt.WindowStaysOnTopHint, on = (checkState == Qt.CheckState.Checked))
# self.setWindowFlag(Qt.WindowStaysOnTopHint, on = True)
# print("_ontop_checkbox_change")
def _button_all_click(self, event):
self._auto_redraw = False
for x in range(self.checklist.count()):
self.checklist.item(x).setCheckState(Qt.CheckState.Checked)
self._auto_redraw = True
self.plotcanvas.draw()
def _button_none_click(self, event):
self._auto_redraw = False
for x in range(self.checklist.count()):
self.checklist.item(x).setCheckState(Qt.CheckState.Unchecked)
self._auto_redraw = True
self.plotcanvas.draw()
def _item_change(self, listitem):
key = listitem.text()
self._lines[key][0].set(
visible=(listitem.checkState() == Qt.CheckState.Checked)
)
if self._auto_redraw:
if self.legend_checkbox.checkState == Qt.CheckState.Checked:
self.plotcanvas.legend = self.plotcanvas.axes.legend()
self.plotcanvas.draw()
# Simple test of the dialog
if __name__ == "__main__":
app = QApplication(sys.argv)
# sim_res = {"c1": [3, 6, 7], "c2": [4, 5, 5], "bfly1.0": [7, 0, 0], "bfly1.1": [-1, 0, 2], "0": [1, 2, 3]}
sim_res = {
'0': [0.5, 0.5, 0, 0],
'add1': [0.5, 0.5, 0, 0],
'cmul1': [0, 0.5, 0, 0],
'cmul2': [0.5, 0, 0, 0],
'in1': [1, 0, 0, 0],
't1': [0, 1, 0, 0],
}
win = PlotWindow(
# window=None, sim_result=sim_res, sfg_name="hej", logger=print
sim_result=sim_res,
)
win.exec_()
# win.show()
"""
B-ASIC select SFG window.
"""
from typing import TYPE_CHECKING
from qtpy.QtCore import Qt, Signal
from qtpy.QtWidgets import QComboBox, QDialog, QPushButton, QVBoxLayout
if TYPE_CHECKING:
from b_asic.GUI.main_window import MainWindow
class SelectSFGWindow(QDialog):
ok = Signal()
def __init__(self, window):
def __init__(self, window: "MainWindow"):
super().__init__()
self._window = window
self.setWindowFlags(Qt.WindowTitleHint | Qt.WindowCloseButtonHint)
......@@ -23,9 +27,8 @@ class SelectSFGWindow(QDialog):
self.sfg = None
self.setLayout(self.dialog_layout)
self.add_sfgs_to_layout()
def add_sfgs_to_layout(self):
# Add SFGs to layout
for sfg in self._window.sfg_dict:
self.combo_box.addItem(sfg)
......
......@@ -62,6 +62,24 @@ class Constant(AbstractOperation):
"""Set the constant value of this operation."""
self.set_param("value", value)
@property
def latency(self) -> int:
return self.latency_offsets["out0"]
def __repr__(self) -> str:
return f"Constant({self.value})"
def __str__(self) -> str:
return f"{self.value}"
@property
def is_linear(self) -> bool:
return True
@property
def is_constant(self) -> bool:
return True
class Addition(AbstractOperation):
"""
......@@ -125,6 +143,10 @@ class Addition(AbstractOperation):
def evaluate(self, a, b):
return a + b
@property
def is_linear(self) -> bool:
return True
class Subtraction(AbstractOperation):
"""
......@@ -185,6 +207,10 @@ class Subtraction(AbstractOperation):
def evaluate(self, a, b):
return a - b
@property
def is_linear(self) -> bool:
return True
class AddSub(AbstractOperation):
r"""
......@@ -266,6 +292,10 @@ class AddSub(AbstractOperation):
"""Set if operation is an addition."""
self.set_param("is_add", is_add)
@property
def is_linear(self) -> bool:
return True
class Multiplication(AbstractOperation):
r"""
......@@ -327,6 +357,12 @@ class Multiplication(AbstractOperation):
def evaluate(self, a, b):
return a * b
@property
def is_linear(self) -> bool:
return any(
input.connected_source.operation.is_constant for input in self.inputs
)
class Division(AbstractOperation):
r"""
......@@ -368,6 +404,10 @@ class Division(AbstractOperation):
def evaluate(self, a, b):
return a / b
@property
def is_linear(self) -> bool:
return self.input(1).connected_source.operation.is_constant
class Min(AbstractOperation):
r"""
......@@ -410,9 +450,7 @@ class Min(AbstractOperation):
def evaluate(self, a, b):
if isinstance(a, complex) or isinstance(b, complex):
raise ValueError(
"core_operations.Min does not support complex numbers."
)
raise ValueError("core_operations.Min does not support complex numbers.")
return a if a < b else b
......@@ -457,9 +495,7 @@ class Max(AbstractOperation):
def evaluate(self, a, b):
if isinstance(a, complex) or isinstance(b, complex):
raise ValueError(
"core_operations.Max does not support complex numbers."
)
raise ValueError("core_operations.Max does not support complex numbers.")
return a if a > b else b
......@@ -589,8 +625,7 @@ class ConstantMultiplication(AbstractOperation):
latency_offsets: Optional[Dict[str, int]] = None,
execution_time: Optional[int] = None,
):
"""Construct a ConstantMultiplication operation with the given value.
"""
"""Construct a ConstantMultiplication operation with the given value."""
super().__init__(
input_count=1,
output_count=1,
......@@ -619,6 +654,10 @@ class ConstantMultiplication(AbstractOperation):
"""Set the constant value of this operation."""
self.set_param("value", value)
@property
def is_linear(self) -> bool:
return True
class Butterfly(AbstractOperation):
r"""
......@@ -661,6 +700,10 @@ class Butterfly(AbstractOperation):
def evaluate(self, a, b):
return a + b, a - b
@property
def is_linear(self) -> bool:
return True
class MAD(AbstractOperation):
r"""
......@@ -700,6 +743,13 @@ class MAD(AbstractOperation):
def evaluate(self, a, b, c):
return a * b + c
@property
def is_linear(self) -> bool:
return (
self.input(0).connected_source.operation.is_constant
or self.input(1).connected_source.operation.is_constant
)
class SymmetricTwoportAdaptor(AbstractOperation):
r"""
......@@ -752,6 +802,10 @@ class SymmetricTwoportAdaptor(AbstractOperation):
"""Set the constant value of this operation."""
self.set_param("value", value)
@property
def is_linear(self) -> bool:
return True
class Reciprocal(AbstractOperation):
r"""
......
......@@ -18,6 +18,8 @@ from b_asic._version import __version__
class AboutWindow(QDialog):
"""About window."""
def __init__(self, window):
super().__init__()
self._window = window
......@@ -27,16 +29,16 @@ class AboutWindow(QDialog):
self.dialog_layout = QVBoxLayout()
self.setLayout(self.dialog_layout)
self.add_information_to_layout()
self._add_information_to_layout()
def hoverText(self, url):
def _hover_text(self, url):
# self.setWindowTitle(url) # When removing mouse, the title gets "B-ASIC Scheduler". Where does THAT come from?
if url:
QToolTip.showText(QCursor.pos(), url)
else:
QToolTip.hideText()
def add_information_to_layout(self):
def _add_information_to_layout(self):
# |1 Title |2 |
# | License | Logo | <- layout12
# | Version | |
......@@ -60,7 +62,7 @@ class AboutWindow(QDialog):
label1.setTextFormat(Qt.MarkdownText)
label1.setWordWrap(True)
label1.setOpenExternalLinks(True)
label1.linkHovered.connect(self.hoverText)
label1.linkHovered.connect(self._hover_text)
self.logo2 = QLabel(self)
self.logo2.setPixmap(
......@@ -77,7 +79,7 @@ class AboutWindow(QDialog):
""" <a href="https://gitlab.liu.se/da/B-ASIC/-/issues">report issues and suggestions</a>."""
)
label3.setOpenExternalLinks(True)
label3.linkHovered.connect(self.hoverText)
label3.linkHovered.connect(self._hover_text)
button4 = QPushButton()
button4.setText("OK")
......@@ -105,7 +107,8 @@ class AboutWindow(QDialog):
# ONLY FOR DEBUG below
def start_about_window():
def show_about_window():
"""Simply show the about window."""
app = QApplication(sys.argv)
window = AboutWindow(QDialog)
window.show()
......@@ -113,4 +116,4 @@ def start_about_window():
if __name__ == "__main__":
start_about_window()
show_about_window()
......@@ -25,12 +25,7 @@ from typing import (
overload,
)
from b_asic.graph_component import (
AbstractGraphComponent,
GraphComponent,
GraphID,
Name,
)
from b_asic.graph_component import AbstractGraphComponent, GraphComponent, GraphID, Name
from b_asic.port import InputPort, OutputPort, SignalSourceProvider
from b_asic.signal import Signal
from b_asic.types import Num
......@@ -403,9 +398,7 @@ class Operation(GraphComponent, SignalSourceProvider):
@abstractmethod
def get_plot_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]:
"""
Return a tuple containing coordinates for the two polygons outlining
the latency and execution time of the operation.
......@@ -413,24 +406,6 @@ class Operation(GraphComponent, SignalSourceProvider):
"""
raise NotImplementedError
@abstractmethod
def get_io_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
"""
Return a tuple containing coordinates for inputs and outputs, respectively.
These maps to the polygons and are corresponding to a start time of 0
and height 1.
See also
========
get_input_coordinates
get_output_coordinates
"""
raise NotImplementedError
@abstractmethod
def get_input_coordinates(
self,
......@@ -442,7 +417,6 @@ class Operation(GraphComponent, SignalSourceProvider):
See also
========
get_io_coordinates
get_output_coordinates
"""
raise NotImplementedError
......@@ -459,7 +433,6 @@ class Operation(GraphComponent, SignalSourceProvider):
See also
========
get_input_coordinates
get_io_coordinates
"""
raise NotImplementedError
......@@ -494,6 +467,22 @@ class Operation(GraphComponent, SignalSourceProvider):
def _check_all_latencies_set(self) -> None:
raise NotImplementedError
@property
@abstractmethod
def is_linear(self) -> bool:
"""
Return True if the operation is linear.
"""
raise NotImplementedError
@property
@abstractmethod
def is_constant(self) -> bool:
"""
Return True if the output of the operation is constant.
"""
raise NotImplementedError
class AbstractOperation(Operation, AbstractGraphComponent):
"""
......@@ -512,9 +501,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
input_count: int,
output_count: int,
name: Name = Name(""),
input_sources: Optional[
Sequence[Optional[SignalSourceProvider]]
] = None,
input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None,
latency: Optional[int] = None,
latency_offsets: Optional[Dict[str, int]] = None,
execution_time: Optional[int] = None,
......@@ -575,9 +562,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
@overload
@abstractmethod
def evaluate(
self, *inputs: Num
) -> List[Num]: # pylint: disable=arguments-differ
def evaluate(self, *inputs: Num) -> List[Num]: # pylint: disable=arguments-differ
...
@abstractmethod
......@@ -601,34 +586,25 @@ class AbstractOperation(Operation, AbstractGraphComponent):
# Import here to avoid circular imports.
from b_asic.core_operations import Addition, Constant
return Addition(
Constant(src) if isinstance(src, Number) else src, self
)
return Addition(Constant(src) if isinstance(src, Number) else src, self)
def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Subtraction
return Subtraction(
self, Constant(src) if isinstance(src, Number) else src
)
return Subtraction(self, Constant(src) if isinstance(src, Number) else src)
def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Subtraction
return Subtraction(
Constant(src) if isinstance(src, Number) else src, self
)
return Subtraction(Constant(src) if isinstance(src, Number) else src, self)
def __mul__(
self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
# Import here to avoid circular imports.
from b_asic.core_operations import (
ConstantMultiplication,
Multiplication,
)
from b_asic.core_operations import ConstantMultiplication, Multiplication
return (
ConstantMultiplication(src, self)
......@@ -640,10 +616,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
# Import here to avoid circular imports.
from b_asic.core_operations import (
ConstantMultiplication,
Multiplication,
)
from b_asic.core_operations import ConstantMultiplication, Multiplication
return (
ConstantMultiplication(src, self)
......@@ -655,9 +628,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Division
return Division(
self, Constant(src) if isinstance(src, Number) else src
)
return Division(self, Constant(src) if isinstance(src, Number) else src)
def __rtruediv__(
self, src: Union[SignalSourceProvider, Num]
......@@ -835,8 +806,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
self, delays: Optional[DelayMap] = None, prefix: str = ""
) -> Sequence[Optional[Num]]:
return [
self.current_output(i, delays, prefix)
for i in range(self.output_count)
self.current_output(i, delays, prefix) for i in range(self.output_count)
]
def evaluate_outputs(
......@@ -927,9 +897,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
Operations input ports.
"""
return [
signal.source.operation
for signal in self.input_signals
if signal.source
signal.source.operation for signal in self.input_signals if signal.source
]
@property
......@@ -1008,10 +976,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return max(
(
(
cast(int, output.latency_offset)
- cast(int, input.latency_offset)
)
(cast(int, output.latency_offset) - cast(int, input.latency_offset))
for output, input in it.product(self.outputs, self.inputs)
)
)
......@@ -1039,7 +1004,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
if any(val is None for val in latency_offsets):
raise ValueError(
"Missing latencies for inputs"
"Missing latencies for input(s)"
f" {[i for (i, latency) in enumerate(latency_offsets) if latency is None]}"
)
......@@ -1050,8 +1015,8 @@ class AbstractOperation(Operation, AbstractGraphComponent):
if any(val is None for val in latency_offsets):
raise ValueError(
"Missing latencies for outputs"
f" {[i for i in latency_offsets if i is not None]}"
"Missing latencies for output(s)"
f" {[i for (i, latency) in enumerate(latency_offsets) if latency is None]}"
)
return cast(List[int], latency_offsets)
......@@ -1116,9 +1081,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def get_plot_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]:
# Doc-string inherited
return (
self._get_plot_coordinates_for_latency(),
......@@ -1169,28 +1132,34 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def get_input_coordinates(self) -> Tuple[Tuple[float, float], ...]:
# doc-string inherited
num_in = self.input_count
return tuple(
(
self.input_latency_offsets()[k],
(1 + 2 * k) / (2 * len(self.inputs)),
(1 + 2 * k) / (2 * num_in),
)
for k in range(len(self.inputs))
for k in range(num_in)
)
def get_output_coordinates(self) -> Tuple[Tuple[float, float], ...]:
# doc-string inherited
num_out = self.output_count
return tuple(
(
self.output_latency_offsets()[k],
(1 + 2 * k) / (2 * len(self.outputs)),
(1 + 2 * k) / (2 * num_out),
)
for k in range(len(self.outputs))
for k in range(num_out)
)
def get_io_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
# Doc-string inherited
return self.get_input_coordinates(), self.get_output_coordinates()
@property
def is_linear(self) -> bool:
if self.is_constant:
return True
return False
@property
def is_constant(self) -> bool:
return all(
input.connected_source.operation.is_constant for input in self.inputs
)
......@@ -49,8 +49,7 @@ class Port(ABC):
@latency_offset.setter
@abstractmethod
def latency_offset(self, latency_offset: int) -> None:
"""Set the latency_offset of the port to the integer specified value.
"""
"""Set the latency_offset of the port to the integer specified value."""
raise NotImplementedError
@property
......@@ -94,6 +93,12 @@ class Port(ABC):
"""Removes all connected signals from the Port."""
raise NotImplementedError
@property
@abstractmethod
def name(self) -> str:
"""Return a name consisting of *graph_id* of the related operation and the port number.
"""
class AbstractPort(Port):
"""
......@@ -134,6 +139,10 @@ class AbstractPort(Port):
def latency_offset(self, latency_offset: Optional[int]):
self._latency_offset = latency_offset
@property
def name(self):
return f"{self.operation.graph_id}.{self.index}"
class SignalSourceProvider(ABC):
"""
......@@ -196,13 +205,9 @@ class InputPort(AbstractPort):
Get the output port that is currently connected to this input port,
or None if it is unconnected.
"""
return (
None if self._source_signal is None else self._source_signal.source
)
return None if self._source_signal is None else self._source_signal.source
def connect(
self, src: SignalSourceProvider, name: Name = Name("")
) -> Signal:
def connect(self, src: SignalSourceProvider, name: Name = Name("")) -> Signal:
"""
Connect the provided signal source to this input port by creating a new signal.
Returns the new signal.
......
"""
B-ASIC classes representing resource usage.
"""
"""B-ASIC classes representing resource usage."""
from typing import Dict, Optional, Tuple
......@@ -10,9 +8,10 @@ from b_asic.port import InputPort, OutputPort
class Process:
"""
Object for use in resource allocation. Has a start time and an execution
time. Subclasses will in many cases contain additional information for
resource assignment.
Object for use in resource allocation.
Has a start time and an execution time. Subclasses will in many cases
contain additional information for resource assignment.
Parameters
==========
......@@ -58,6 +57,9 @@ class Process:
def __str__(self) -> str:
return self._name
def __repr__(self) -> str:
return f"Process({self.start_time}, {self.execution_time}, {self.name!r})"
# Static counter for default names
_name_cnt = 0
......@@ -85,8 +87,7 @@ class OperatorProcess(Process):
execution_time = operation.execution_time
if execution_time is None:
raise ValueError(
"Operation {operation!r} does not have an execution time"
" specified!"
"Operation {operation!r} does not have an execution time specified!"
)
super().__init__(
start_time,
......@@ -142,11 +143,19 @@ class MemoryVariable(Process):
def write_port(self) -> OutputPort:
return self._write_port
def __repr__(self) -> str:
reads = {k: v for k, v in zip(self._read_ports, self._life_times)}
return (
f"MemoryVariable({self.start_time}, {self.write_port},"
f" {reads!r}, {self.name!r})"
)
class PlainMemoryVariable(Process):
"""
Object that corresponds to a memory variable which only use numbers for
ports. This can be useful when only a plain memory variable is wanted with
Object that corresponds to a memory variable which only use numbers for ports.
This can be useful when only a plain memory variable is wanted with
no connection to a schedule.
Parameters
......@@ -157,7 +166,8 @@ class PlainMemoryVariable(Process):
Identifier for the source of the memory variable.
reads : {int: int, ...}
Dictionary where the key is the destination identifier and the value
is the time after *write_time* that the memory variable is read.
is the time after *write_time* that the memory variable is read, i.e., the
lifetime of the variable.
name : str, optional
The name of the process.
"""
......@@ -189,3 +199,10 @@ class PlainMemoryVariable(Process):
@property
def write_port(self) -> int:
return self._write_port
def __repr__(self) -> str:
reads = {k: v for k, v in zip(self._read_ports, self._life_times)}
return (
f"PlainMemoryVariable({self.start_time}, {self.write_port},"
f" {reads!r}, {self.name!r})"
)
import io
import re
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union
......@@ -6,8 +7,12 @@ import networkx as nx
from matplotlib.axes import Axes
from matplotlib.ticker import MaxNLocator
from b_asic._preferences import LATENCY_COLOR
from b_asic.process import Process
# Default latency coloring RGB tuple
_LATENCY_COLOR = tuple(c / 255 for c in LATENCY_COLOR)
#
# Human-intuitive sorting:
# https://stackoverflow.com/questions/2669059/how-to-sort-alpha-numeric-set-in-python
......@@ -20,9 +25,7 @@ _T = TypeVar('_T')
def _sorted_nicely(to_be_sorted: Iterable[_T]) -> List[_T]:
"""Sort the given iterable in the way that humans expect."""
convert = lambda text: int(text) if text.isdigit() else text
alphanum_key = lambda key: [
convert(c) for c in re.split('([0-9]+)', str(key))
]
alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', str(key))]
return sorted(to_be_sorted, key=alphanum_key)
......@@ -30,9 +33,7 @@ def draw_exclusion_graph_coloring(
exclusion_graph: nx.Graph,
color_dict: Dict[Process, int],
ax: Optional[Axes] = None,
color_list: Optional[
Union[List[str], List[Tuple[float, float, float]]]
] = None,
color_list: Optional[Union[List[str], List[Tuple[float, float, float]]]] = None,
):
"""
Use matplotlib.pyplot and networkx to draw a colored exclusion graph from the memory assignment
......@@ -114,6 +115,13 @@ class ProcessCollection:
self._schedule_time = schedule_time
self._cyclic = cyclic
@property
def collection(self):
return self._collection
def __len__(self):
return len(self.__collection__)
def add_process(self, process: Process):
"""
Add a new process to this process collection.
......@@ -129,6 +137,11 @@ class ProcessCollection:
self,
ax: Optional[Axes] = None,
show_name: bool = True,
bar_color: Union[str, Tuple[float, ...]] = _LATENCY_COLOR,
marker_color: Union[str, Tuple[float, ...]] = "black",
marker_read: str = "X",
marker_write: str = "o",
show_markers: bool = True,
):
"""
Use matplotlib.pyplot to generate a process variable lifetime chart from this process collection.
......@@ -140,6 +153,16 @@ class ProcessCollection:
this method will return a new axes object on return.
show_name : bool, default: True
Show name of all processes in the lifetime chart.
bar_color : color, optional
Bar color in lifetime chart.
marker_color : color, default 'black'
Color for read and write marker.
marker_write : str, default 'x'
Marker at write time in the lifetime chart.
marker_read : str, default 'o'
Marker at read time in the lifetime chart.
show_markers : bool, default True
Show markers at read and write times.
Returns
-------
......@@ -154,9 +177,7 @@ class ProcessCollection:
# Lifetime chart left and right padding
PAD_L, PAD_R = 0.05, 0.05
max_execution_time = max(
process.execution_time for process in self._collection
)
max_execution_time = max(process.execution_time for process in self._collection)
if max_execution_time > self._schedule_time:
# Schedule time needs to be greater than or equal to the maximum process lifetime
raise KeyError(
......@@ -167,42 +188,47 @@ class ProcessCollection:
# Generate the life-time chart
for i, process in enumerate(_sorted_nicely(self._collection)):
bar_start = process.start_time % self._schedule_time
bar_end = process.start_time + process.execution_time
bar_end = (
process.start_time + process.execution_time
) % self._schedule_time
if process.execution_time == 0:
# Process has no execution time, draw a tick
_ax.scatter(x=bar_start, y=i + 1, marker='X', color='blue')
elif bar_end > bar_start:
bar_end
if bar_end == self._schedule_time
else bar_end % self._schedule_time
)
if show_markers:
_ax.scatter(
x=bar_start,
y=i + 1,
marker=marker_write,
color=marker_color,
zorder=10,
)
_ax.scatter(
x=bar_end,
y=i + 1,
marker=marker_read,
color=marker_color,
zorder=10,
)
if bar_end >= bar_start:
_ax.broken_barh(
[(PAD_L + bar_start, bar_end - bar_start - PAD_L - PAD_R)],
(i + 0.55, 0.9),
color=bar_color,
)
else: # bar_end < bar_start
if bar_end != 0:
_ax.broken_barh(
[
(
PAD_L + bar_start,
self._schedule_time - bar_start - PAD_L,
)
],
(i + 0.55, 0.9),
)
_ax.broken_barh([(0, bar_end - PAD_R)], (i + 0.55, 0.9))
else:
_ax.broken_barh(
[
(
PAD_L + bar_start,
self._schedule_time
- bar_start
- PAD_L
- PAD_R,
)
],
(i + 0.55, 0.9),
)
_ax.broken_barh(
[
(
PAD_L + bar_start,
self._schedule_time - bar_start - PAD_L,
)
],
(i + 0.55, 0.9),
color=bar_color,
)
_ax.broken_barh(
[(0, bar_end - PAD_R)], (i + 0.55, 0.9), color=bar_color
)
if show_name:
_ax.annotate(
str(process),
......@@ -217,16 +243,84 @@ class ProcessCollection:
_ax.set_ylim(0.25, len(self._collection) + 0.75)
return _ax
def create_exclusion_graph_from_overlap(
self, add_name: bool = True
def create_exclusion_graph_from_ports(
self,
read_ports: Optional[int] = None,
write_ports: Optional[int] = None,
total_ports: Optional[int] = None,
) -> nx.Graph:
"""
Generate exclusion graph based on processes overlapping in time
Create an exclusion graph from a ProcessCollection based on a number of read/write ports
Parameters
----------
add_name : bool, default: True
Add name of all processes as a node attribute in the exclusion graph.
read_ports : int
The number of read ports used when splitting process collection based on memory variable access.
write_ports : int
The number of write ports used when splitting process collection based on memory variable access.
total_ports : int
The total number of ports used when splitting process collection based on memory variable access.
Returns
-------
nx.Graph
"""
if total_ports is None:
if read_ports is None or write_ports is None:
raise ValueError(
"If total_ports is unset, both read_ports and write_ports"
" must be provided."
)
else:
total_ports = read_ports + write_ports
else:
read_ports = total_ports if read_ports is None else read_ports
write_ports = total_ports if write_ports is None else write_ports
# Guard for proper read/write port settings
if read_ports != 1 or write_ports != 1:
raise ValueError(
"Splitting with read and write ports not equal to one with the"
" graph coloring heuristic does not make sense."
)
if total_ports not in (1, 2):
raise ValueError(
"Total ports should be either 1 (non-concurrent reads/writes)"
" or 2 (concurrent read/writes) for graph coloring heuristic."
)
# Create new exclusion graph. Nodes are Processes
exclusion_graph = nx.Graph()
exclusion_graph.add_nodes_from(self._collection)
for node1 in exclusion_graph:
for node2 in exclusion_graph:
if node1 == node2:
continue
else:
node1_stop_time = node1.start_time + node1.execution_time
node2_stop_time = node2.start_time + node2.execution_time
if total_ports == 1:
# Single-port assignment
if node1.start_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
elif node1.start_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
else:
# Dual-port assignment
if node1.start_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
return exclusion_graph
def create_exclusion_graph_from_execution_time(self) -> nx.Graph:
"""
Generate exclusion graph based on processes overlapping in time
Returns
-------
......@@ -256,7 +350,47 @@ class ProcessCollection:
exclusion_graph.add_edge(process1, process2)
return exclusion_graph
def split(
def split_execution_time(
self, heuristic: str = "graph_color", coloring_strategy: str = "DSATUR"
) -> Set["ProcessCollection"]:
"""
Split a ProcessCollection based on overlapping execution time.
Parameters
----------
heuristic : str, default: 'graph_color'
The heuristic used when splitting based on execution times.
One of: 'graph_color', 'left_edge'.
coloring_strategy: str, default: 'DSATUR'
Node ordering strategy passed to nx.coloring.greedy_color() if the heuristic is set to 'graph_color'. This
parameter is only considered if heuristic is set to graph_color.
One of
* `'largest_first'`
* `'random_sequential'`
* `'smallest_last'`
* `'independent_set'`
* `'connected_sequential_bfs'`
* `'connected_sequential_dfs'`
* `'connected_sequential'` (alias for the previous strategy)
* `'saturation_largest_first'`
* `'DSATUR'` (alias for the saturation_largest_first strategy)
Returns
-------
A set of new ProcessCollection objects with the process splitting.
"""
if heuristic == "graph_color":
exclusion_graph = self.create_exclusion_graph_from_execution_time()
coloring = nx.coloring.greedy_color(
exclusion_graph, strategy=coloring_strategy
)
return self._split_from_graph_coloring(coloring)
elif heuristic == "left_edge":
raise NotImplementedError()
else:
raise ValueError(f"Invalid heuristic '{heuristic}'")
def split_ports(
self,
heuristic: str = "graph_color",
read_ports: Optional[int] = None,
......@@ -286,85 +420,97 @@ class ProcessCollection:
"""
if total_ports is None:
if read_ports is None or write_ports is None:
raise ValueError("inteligent quote")
raise ValueError(
"If total_ports is unset, both read_ports and write_ports"
" must be provided."
)
else:
total_ports = read_ports + write_ports
else:
read_ports = total_ports if read_ports is None else read_ports
write_ports = total_ports if write_ports is None else write_ports
if heuristic == "graph_color":
return self._split_graph_color(
read_ports, write_ports, total_ports
)
return self._split_ports_graph_color(read_ports, write_ports, total_ports)
else:
raise ValueError("Invalid heuristic provided")
raise ValueError("Invalid heuristic provided.")
def _split_graph_color(
self, read_ports: int, write_ports: int, total_ports: int
def _split_ports_graph_color(
self,
read_ports: int,
write_ports: int,
total_ports: int,
coloring_strategy: str = "DSATUR",
) -> Set["ProcessCollection"]:
"""
Parameters
----------
read_ports : int, optional
read_ports : int
The number of read ports used when splitting process collection based on memory variable access.
write_ports : int, optional
write_ports : int
The number of write ports used when splitting process collection based on memory variable access.
total_ports : int, optional
total_ports : int
The total number of ports used when splitting process collection based on memory variable access.
coloring_strategy: str, default: 'DSATUR'
Node ordering strategy passed to nx.coloring.greedy_color()
One of
* `'largest_first'`
* `'random_sequential'`
* `'smallest_last'`
* `'independent_set'`
* `'connected_sequential_bfs'`
* `'connected_sequential_dfs'`
* `'connected_sequential'` (alias for the previous strategy)
* `'saturation_largest_first'`
* `'DSATUR'` (alias for the saturation_largest_first strategy)
"""
if read_ports != 1 or write_ports != 1:
raise ValueError(
"Splitting with read and write ports not equal to one with the"
" graph coloring heuristic does not make sense."
)
if total_ports not in (1, 2):
raise ValueError(
"Total ports should be either 1 (non-concurrent reads/writes)"
" or 2 (concurrent read/writes) for graph coloring heuristic."
)
# Create new exclusion graph. Nodes are Processes
exclusion_graph = nx.Graph()
exclusion_graph.add_nodes_from(self._collection)
exclusion_graph = self.create_exclusion_graph_from_ports(
read_ports, write_ports, total_ports
)
# Add exclusions (arcs) between processes in the exclusion graph
for node1 in exclusion_graph:
for node2 in exclusion_graph:
if node1 == node2:
continue
else:
node1_stop_time = node1.start_time + node1.execution_time
node2_stop_time = node2.start_time + node2.execution_time
if total_ports == 1:
# Single-port assignment
if node1.start_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
elif node1.start_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
else:
# Dual-port assignment
if node1.start_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
# Perform assignment from coloring and return result
coloring = nx.coloring.greedy_color(exclusion_graph, strategy=coloring_strategy)
return self._split_from_graph_coloring(coloring)
def _split_from_graph_coloring(
self,
coloring: Dict[Process, int],
) -> Set["ProcessCollection"]:
"""
Split :class:`Process` objects into a set of :class:`ProcessesCollection` objects based on a provided graph coloring.
Resulting :class:`ProcessCollection` will have the same schedule time and cyclic propoery as self.
# Perform assignment
coloring = nx.coloring.greedy_color(exclusion_graph)
draw_exclusion_graph_coloring(exclusion_graph, coloring)
# process_collection_list = [ProcessCollection()]*(max(coloring.values()) + 1)
process_collection_set_list = [
set() for _ in range(max(coloring.values()) + 1)
]
Parameters
----------
coloring : Dict[Process, int]
Process->int (color) mappings
Returns
-------
A set of new ProcessCollections.
"""
process_collection_set_list = [set() for _ in range(max(coloring.values()) + 1)]
for process, color in coloring.items():
process_collection_set_list[color].add(process)
return {
ProcessCollection(
process_collection_set, self._schedule_time, self._cyclic
)
ProcessCollection(process_collection_set, self._schedule_time, self._cyclic)
for process_collection_set in process_collection_set_list
}
def _repr_svg_(self) -> str:
"""
Generate an SVG_ of the resource collection. This is automatically displayed in e.g.
Jupyter Qt console.
"""
fig, ax = plt.subplots()
self.draw_lifetime_chart(ax, show_markers=False)
f = io.StringIO()
fig.savefig(f, format="svg")
return f.getvalue()
def __repr__(self):
return (
f"ProcessCollection({self._collection}, {self._schedule_time},"
f" {self._cyclic})"
)
......@@ -32,8 +32,9 @@ from b_asic.graph_component import GraphID
from b_asic.operation import Operation
from b_asic.port import InputPort, OutputPort
from b_asic.process import MemoryVariable, Process
from b_asic.resources import ProcessCollection
from b_asic.signal_flow_graph import SFG
from b_asic.special_operations import Delay, Output
from b_asic.special_operations import Delay, Input, Output
# Need RGB from 0 to 1
_EXECUTION_TIME_COLOR = tuple(c / 255 for c in EXECUTION_TIME_COLOR)
......@@ -90,9 +91,7 @@ class Schedule:
if schedule_time is None:
self._schedule_time = max_end_time
elif schedule_time < max_end_time:
raise ValueError(
f"Too short schedule time. Minimum is {max_end_time}."
)
raise ValueError(f"Too short schedule time. Minimum is {max_end_time}.")
else:
self._schedule_time = schedule_time
......@@ -101,9 +100,7 @@ class Schedule:
Return the start time of the operation with the specified by *graph_id*.
"""
if graph_id not in self._start_times:
raise ValueError(
f"No operation with graph_id {graph_id} in schedule"
)
raise ValueError(f"No operation with graph_id {graph_id} in schedule")
return self._start_times[graph_id]
def get_max_end_time(self) -> int:
......@@ -138,9 +135,7 @@ class Schedule:
slacks
"""
if graph_id not in self._start_times:
raise ValueError(
f"No operation with graph_id {graph_id} in schedule"
)
raise ValueError(f"No operation with graph_id {graph_id} in schedule")
slack = sys.maxsize
output_slacks = self._forward_slacks(graph_id)
# Make more pythonic
......@@ -193,9 +188,7 @@ class Schedule:
slacks
"""
if graph_id not in self._start_times:
raise ValueError(
f"No operation with graph_id {graph_id} in schedule"
)
raise ValueError(f"No operation with graph_id {graph_id} in schedule")
slack = sys.maxsize
input_slacks = self._backward_slacks(graph_id)
# Make more pythonic
......@@ -204,9 +197,7 @@ class Schedule:
slack = min(slack, signal_slack)
return slack
def _backward_slacks(
self, graph_id: GraphID
) -> Dict[InputPort, Dict[Signal, int]]:
def _backward_slacks(self, graph_id: GraphID) -> Dict[InputPort, Dict[Signal, int]]:
ret = {}
start_time = self._start_times[graph_id]
op = cast(Operation, self._sfg.find_by_id(graph_id))
......@@ -249,9 +240,7 @@ class Schedule:
"""
if graph_id not in self._start_times:
raise ValueError(
f"No operation with graph_id {graph_id} in schedule"
)
raise ValueError(f"No operation with graph_id {graph_id} in schedule")
return self.backward_slack(graph_id), self.forward_slack(graph_id)
def print_slacks(self) -> None:
......@@ -308,13 +297,11 @@ class Schedule:
factor : int
The time resolution increment.
"""
self._start_times = {
k: factor * v for k, v in self._start_times.items()
}
self._start_times = {k: factor * v for k, v in self._start_times.items()}
for graph_id in self._start_times:
cast(
Operation, self._sfg.find_by_id(graph_id)
)._increase_time_resolution(factor)
cast(Operation, self._sfg.find_by_id(graph_id))._increase_time_resolution(
factor
)
self._schedule_time *= factor
return self
......@@ -365,13 +352,11 @@ class Schedule:
f"Not possible to decrease resolution with {factor}. Possible"
f" values are {possible_values}"
)
self._start_times = {
k: v // factor for k, v in self._start_times.items()
}
self._start_times = {k: v // factor for k, v in self._start_times.items()}
for graph_id in self._start_times:
cast(
Operation, self._sfg.find_by_id(graph_id)
)._decrease_time_resolution(factor)
cast(Operation, self._sfg.find_by_id(graph_id))._decrease_time_resolution(
factor
)
self._schedule_time = self._schedule_time // factor
return self
......@@ -387,9 +372,7 @@ class Schedule:
The time to move. If positive move forward, if negative move backward.
"""
if graph_id not in self._start_times:
raise ValueError(
f"No operation with graph_id {graph_id} in schedule"
)
raise ValueError(f"No operation with graph_id {graph_id} in schedule")
(backward_slack, forward_slack) = self.slacks(graph_id)
if not -backward_slack <= time <= forward_slack:
......@@ -412,15 +395,25 @@ class Schedule:
tmp_prev_available = tmp_usage - new_slack
prev_available = tmp_prev_available % self._schedule_time
laps = new_slack // self._schedule_time
source_op = signal.source.operation
if new_usage < prev_available:
print("Incrementing input laps 1")
laps += 1
if prev_available == 0 and new_usage == 0:
if (
prev_available == 0
and new_usage == 0
and (
tmp_prev_available > 0
or tmp_prev_available == 0
and not isinstance(source_op, Input)
)
):
print("Incrementing input laps 2")
laps += 1
print(
[
"Input",
signal.source.operation,
time,
tmp_start,
signal_slack,
......@@ -475,12 +468,8 @@ class Schedule:
while delay_list:
delay_op = cast(Delay, delay_list[0])
delay_input_id = delay_op.input(0).signals[0].graph_id
delay_output_ids = [
sig.graph_id for sig in delay_op.output(0).signals
]
self._sfg = cast(
SFG, self._sfg.remove_operation(delay_op.graph_id)
)
delay_output_ids = [sig.graph_id for sig in delay_op.output(0).signals]
self._sfg = cast(SFG, self._sfg.remove_operation(delay_op.graph_id))
for output_id in delay_output_ids:
self._laps[output_id] += 1 + self._laps[delay_input_id]
del self._laps[delay_input_id]
......@@ -519,21 +508,16 @@ class Schedule:
for inport in op.inputs:
if len(inport.signals) != 1:
raise ValueError(
"Error in scheduling, dangling input port"
" detected."
"Error in scheduling, dangling input port detected."
)
if inport.signals[0].source is None:
raise ValueError(
"Error in scheduling, signal with no source"
" detected."
"Error in scheduling, signal with no source detected."
)
source_port = inport.signals[0].source
source_end_time = None
if (
source_port.operation.graph_id
in non_schedulable_ops
):
if source_port.operation.graph_id in non_schedulable_ops:
source_end_time = 0
else:
source_op_time = self._start_times[
......@@ -558,12 +542,8 @@ class Schedule:
f" {inport.operation.graph_id} has no"
" latency-offset."
)
op_start_time_from_in = (
source_end_time - inport.latency_offset
)
op_start_time = max(
op_start_time, op_start_time_from_in
)
op_start_time_from_in = source_end_time - inport.latency_offset
op_start_time = max(op_start_time, op_start_time_from_in)
self._start_times[op.graph_id] = op_start_time
for output in self._sfg.find_by_type_name(Output.type_name()):
......@@ -583,8 +563,8 @@ class Schedule:
] + cast(int, source_port.latency_offset)
self._remove_delays()
def _get_memory_variables_list(self) -> List['Process']:
ret: List['Process'] = []
def _get_memory_variables_list(self) -> List['MemoryVariable']:
ret: List['MemoryVariable'] = []
for graph_id, start_time in self._start_times.items():
slacks = self._forward_slacks(graph_id)
for outport, signals in slacks.items():
......@@ -597,10 +577,25 @@ class Schedule:
start_time + cast(int, outport.latency_offset),
outport,
reads,
outport.name,
)
)
return ret
def get_memory_variables(self) -> ProcessCollection:
"""
Return a :class:`~b_asic.resources.ProcessCollection` containing all
memory variables.
Returns
-------
ProcessCollection
"""
return ProcessCollection(
set(self._get_memory_variables_list()), self.schedule_time
)
def _get_y_position(
self, graph_id, operation_height=1.0, operation_gap=None
) -> float:
......@@ -609,17 +604,13 @@ class Schedule:
y_location = self._y_locations[graph_id]
if y_location is None:
# Assign the lowest row number not yet in use
used = set(
loc for loc in self._y_locations.values() if loc is not None
)
used = set(loc for loc in self._y_locations.values() if loc is not None)
possible = set(range(len(self._start_times))) - used
y_location = min(possible)
self._y_locations[graph_id] = y_location
return operation_gap + y_location * (operation_height + operation_gap)
def _plot_schedule(
self, ax: Axes, operation_gap: Optional[float] = None
) -> None:
def _plot_schedule(self, ax: Axes, operation_gap: Optional[float] = None) -> None:
"""Draw the schedule."""
line_cache = []
......@@ -706,9 +697,7 @@ class Schedule:
)
ax.add_patch(pp)
def _draw_offset_arrow(
start, end, start_offset, end_offset, name="", laps=0
):
def _draw_offset_arrow(start, end, start_offset, end_offset, name="", laps=0):
"""Draw an arrow from *start* to *end*, but with an offset."""
_draw_arrow(
[start[0] + start_offset[0], start[1] + start_offset[1]],
......@@ -745,24 +734,18 @@ class Schedule:
linewidth=3,
)
ytickpositions.append(y_pos + 0.5)
yticklabels.append(
cast(Operation, self._sfg.find_by_id(graph_id)).name
)
yticklabels.append(cast(Operation, self._sfg.find_by_id(graph_id)).name)
for graph_id, op_start_time in self._start_times.items():
op = cast(Operation, self._sfg.find_by_id(graph_id))
out_coordinates = op.get_output_coordinates()
source_y_pos = self._get_y_position(
graph_id, operation_gap=operation_gap
)
source_y_pos = self._get_y_position(graph_id, operation_gap=operation_gap)
for output_port in op.outputs:
for output_signal in output_port.signals:
destination = cast(InputPort, output_signal.destination)
destination_op = destination.operation
destination_start_time = self._start_times[
destination_op.graph_id
]
destination_start_time = self._start_times[destination_op.graph_id]
destination_y_pos = self._get_y_position(
destination_op.graph_id, operation_gap=operation_gap
)
......@@ -788,9 +771,7 @@ class Schedule:
+ 1
+ (OPERATION_GAP if operation_gap is None else operation_gap)
)
ax.axis(
[-1, self._schedule_time + 1, y_position_max, 0]
) # Inverted y-axis
ax.axis([-1, self._schedule_time + 1, y_position_max, 0]) # Inverted y-axis
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.axvline(
0,
......@@ -807,9 +788,7 @@ class Schedule:
"""Reset all the y-locations in the schedule to None"""
self._y_locations = self._y_locations = defaultdict(lambda: None)
def plot_in_axes(
self, ax: Axes, operation_gap: Optional[float] = None
) -> None:
def plot_in_axes(self, ax: Axes, operation_gap: Optional[float] = None) -> None:
"""
Plot the schedule in a :class:`matplotlib.axes.Axes` or subclass.
......@@ -821,6 +800,7 @@ class Schedule:
The vertical distance between operations in the schedule. The height of
the operation is always 1.
"""
self._plot_schedule(ax, operation_gap=operation_gap)
def plot(self, operation_gap: Optional[float] = None) -> None:
"""
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""B-ASIC Scheduler-gui Graphics Axes Item Module.
"""
B-ASIC Scheduler-gui Axes Item Module.
Contains the scheduler-gui AxesItem class for drawing and maintain the
Contains the scheduler-gui AxesItem class for drawing and maintaining the
axes in a graph.
"""
from math import pi, sin
......@@ -25,18 +26,19 @@ from b_asic.scheduler_gui.timeline_item import TimelineItem
class AxesItem(QGraphicsItemGroup):
"""
f"""
A class to represent axes in a graph.
Parameters
----------
width
height
width_indent
height_indent
width_padding
height_padding
parent
width_indent : float, default: {SCHEDULE_INDENT}
height_indent : float, default: {SCHEDULE_INDENT}
width_padding : float, default: 0.6
height_padding : float, default: 0.5
parent : QGraphicsItem, optional
Passed to QGraphicsItemGroup.
"""
_scale: float = 1.0
......@@ -72,17 +74,14 @@ class AxesItem(QGraphicsItemGroup):
):
"""
Class for an AxesItem.
*parent* is passed to QGraphicsItemGroup's constructor.
"""
super().__init__(parent=parent)
if width < 0:
raise ValueError(
f"'width' greater or equal to 0 expected, got: {width}."
)
raise ValueError(f"'width' greater or equal to 0 expected, got: {width}.")
if height < 0:
raise ValueError(
f"'height' greater or equal to 0 expected, got: {height}."
)
raise ValueError(f"'height' greater or equal to 0 expected, got: {height}.")
self._width = width
self._height = height
......@@ -170,18 +169,14 @@ class AxesItem(QGraphicsItemGroup):
def set_height(self, height: float) -> None:
# TODO: docstring
if height < 0:
raise ValueError(
f"'height' greater or equal to 0 expected, got: {height}."
)
raise ValueError(f"'height' greater or equal to 0 expected, got: {height}.")
self._height = height
self._update_yaxis()
def set_width(self, width: int) -> None:
# TODO: docstring
if width < 0:
raise ValueError(
f"'width' greater or equal to 0 expected, got: {width}."
)
raise ValueError(f"'width' greater or equal to 0 expected, got: {width}.")
delta_width = width - self._width
......@@ -293,12 +288,7 @@ class AxesItem(QGraphicsItemGroup):
0,
0,
0,
-(
self._height_indent
+ self._height
+ self._height_padding
+ 0.05
),
-(self._height_indent + self._height + self._height_padding + 0.05),
)
self._y_axis.setPen(self._base_pen)
......@@ -325,15 +315,11 @@ class AxesItem(QGraphicsItemGroup):
self._x_label.setScale(1 / self._scale)
x_pos = self._width_indent + 0 + self._width_padding # end of x-axis
x_pos += (
self.mapRectFromItem(
self._x_arrow, self._x_arrow.boundingRect()
).width()
self.mapRectFromItem(self._x_arrow, self._x_arrow.boundingRect()).width()
/ 2
) # + half arrow width
x_pos -= (
self.mapRectFromItem(
self._x_label, self._x_label.boundingRect()
).width()
self.mapRectFromItem(self._x_label, self._x_label.boundingRect()).width()
/ 2
) # - center of label
self._x_label.setPos(x_pos, self._x_label_offset)
......@@ -344,9 +330,7 @@ class AxesItem(QGraphicsItemGroup):
for _ in range(self._width):
self._append_x_tick()
pos = self._x_ledger[-1].pos()
self._x_ledger[-1].setPos(
pos + QPointF(self._width, 0)
) # move timeline
self._x_ledger[-1].setPos(pos + QPointF(self._width, 0)) # move timeline
# y-axis
self._update_yaxis()
......
......@@ -12,6 +12,7 @@ import os
import shutil
import subprocess
import sys
from pathlib import Path
from qtpy import uic
from setuptools_scm import get_version
......@@ -31,31 +32,34 @@ def _check_filenames(*filenames: str) -> None:
exception.
"""
for filename in filenames:
if not os.path.exists(filename):
raise FileNotFoundError(filename)
Path(filename).resolve(strict=True)
def _check_qt_version() -> None:
"""
Check if PySide2 or PyQt5 is installed, otherwise raise AssertionError
Check if PySide2, PyQt5, PySide6, or PyQt6 is installed, otherwise raise AssertionError
exception.
"""
assert uic.PYSIDE2 or uic.PYQT5, "PySide2 or PyQt5 need to be installed"
assert (
uic.PYSIDE2 or uic.PYQT5 or uic.PYSIDE6 or uic.PYQT6
), "Python QT bindings must be installed"
def replace_qt_bindings(filename: str) -> None:
"""Raplaces qt-binding api in 'filename' from PySide2/PyQt5 to qtpy."""
"""Raplaces qt-binding api in *filename* from PySide2/6 or PyQt5/6 to qtpy."""
with open(f"{filename}", "r") as file:
filedata = file.read()
filedata = filedata.replace("from PyQt5", "from qtpy")
filedata = filedata.replace("from PySide2", "from qtpy")
filedata = filedata.replace("from PyQt6", "from qtpy")
filedata = filedata.replace("from PySide6", "from qtpy")
with open(f"{filename}", "w") as file:
file.write(filedata)
def compile_rc(*filenames: str) -> None:
"""
Compile resource file(s) given by 'filenames'. If no arguments are given,
Compile resource file(s) given by *filenames*. If no arguments are given,
the compiler will search for resource (.qrc) files and compile accordingly.
"""
_check_qt_version()
......@@ -70,10 +74,9 @@ def compile_rc(*filenames: str) -> None:
if rcc is None:
rcc = shutil.which("pyrcc5")
arguments = f"-o {outfile} {filename}"
assert rcc, (
"Qt Resource compiler failed, cannot find pyside2-rcc, rcc, or"
" pyrcc5"
)
assert (
rcc
), "Qt Resource compiler failed, cannot find pyside2-rcc, rcc, or pyrcc5"
os_ = sys.platform
if os_.startswith("linux"): # Linux
......@@ -124,7 +127,7 @@ def compile_rc(*filenames: str) -> None:
def compile_ui(*filenames: str) -> None:
"""
Compile form file(s) given by 'filenames'. If no arguments are given, the
Compile form file(s) given by *filenames*. If no arguments are given, the
compiler will search for form (.ui) files and compile accordingly.
"""
_check_qt_version()
......@@ -168,11 +171,43 @@ def compile_ui(*filenames: str) -> None:
log.error(f"{os_} UI compiler not supported")
raise NotImplementedError
else: # uic.PYQT5
elif uic.PYQT5 or uic.PYQT6:
from qtpy.uic import compileUi
with open(outfile, "w") as ofile:
compileUi(filename, ofile)
elif uic.PYQT6:
uic_ = shutil.which("pyside6-uic")
arguments = f"-g python -o {outfile} {filename}"
if uic_ is None:
uic_ = shutil.which("uic")
if uic_ is None:
uic_ = shutil.which("pyuic6")
arguments = f"-o {outfile} {filename}"
assert uic_, (
"Qt User Interface Compiler failed, cannot find pyside6-uic,"
" uic, or pyuic6"
)
os_ = sys.platform
if os_.startswith("linux"): # Linux
cmd = f"{uic_} {arguments}"
subprocess.call(cmd.split())
elif os_.startswith("win32"): # Windows
# TODO: implement
log.error("Windows UI compiler not implemented")
raise NotImplementedError
elif os_.startswith("darwin"): # macOS
# TODO: implement
log.error("macOS UI compiler not implemented")
raise NotImplementedError
else: # other OS
log.error(f"{os_} UI compiler not supported")
raise NotImplementedError
replace_qt_bindings(outfile) # replace qt-bindings with qtpy
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""B-ASIC Scheduler-gui Logger Module.
"""
B-ASIC Scheduler-gui Logger Module.
Contains a logger that logs to the console and a file using levels. It is based
on the :mod:`logging` module and has predefined levels of logging.
......@@ -55,9 +56,7 @@ from types import TracebackType
from typing import Type, Union
def getLogger(
filename: str = "scheduler-gui.log", loglevel: str = "INFO"
) -> Logger:
def getLogger(filename: str = "scheduler-gui.log", loglevel: str = "INFO") -> Logger:
"""
This function creates console- and filehandler and from those, creates a logger
object.
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
B-ASIC Scheduler-gui Module.
B-ASIC Scheduler-GUI Module.
Contains the scheduler-gui MainWindow class for scheduling operations in an SFG.
Contains the scheduler_gui MainWindow class for scheduling operations in an SFG.
Start main-window with start_gui().
Start main-window with ``start_gui()``.
"""
import inspect
import os
......@@ -104,7 +104,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
_zoom: float
def __init__(self):
"""Initialize Scheduler-gui."""
"""Initialize Scheduler-GUI."""
super().__init__()
self._schedule = None
self._graph = None
......@@ -121,9 +121,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
"""Initialize the ui"""
# Connect signals to slots
self.menu_load_from_file.triggered.connect(
self._load_schedule_from_pyfile
)
self.menu_load_from_file.triggered.connect(self._load_schedule_from_pyfile)
self.menu_close_schedule.triggered.connect(self.close_schedule)
self.menu_save.triggered.connect(self.save)
self.menu_save_as.triggered.connect(self.save_as)
......@@ -153,9 +151,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
def _init_graphics(self) -> None:
"""Initialize the QGraphics framework"""
self._scene = QGraphicsScene()
self._scene.addRect(
0, 0, 0, 0
) # dummy rect to be able to setPos() graph
self._scene.addRect(0, 0, 0, 0) # dummy rect to be able to setPos() graph
self.view.setScene(self._scene)
self.view.scale(self._scale, self._scale)
OperationItem._scale = self._scale
......@@ -182,12 +178,12 @@ class MainWindow(QMainWindow, Ui_MainWindow):
@Slot()
def _open_documentation(self) -> None:
"""Callback to open documentation web page."""
webbrowser.open_new_tab("https://da.gitlab-pages.liu.se/B-ASIC/")
@Slot()
def _actionReorder(self) -> None:
"""Callback to reorder all operations vertically based on start time.
"""
"""Callback to reorder all operations vertically based on start time."""
if self.schedule is None:
return
if self._graph is not None:
......@@ -228,25 +224,17 @@ class MainWindow(QMainWindow, Ui_MainWindow):
self.tr("Python Files (*.py *.py3)"),
)
if (
not abs_path_filename
): # return if empty filename (QFileDialog was canceled)
if not abs_path_filename: # return if empty filename (QFileDialog was canceled)
return
log.debug("abs_path_filename = {}.".format(abs_path_filename))
module_name = inspect.getmodulename(abs_path_filename)
if not module_name: # return if empty module name
log.error(
"Could not load module from file '{}'.".format(
abs_path_filename
)
)
log.error("Could not load module from file '{}'.".format(abs_path_filename))
return
try:
module = SourceFileLoader(
module_name, abs_path_filename
).load_module()
module = SourceFileLoader(module_name, abs_path_filename).load_module()
except Exception as e:
log.exception(
"Exception occurred. Could not load module from file"
......@@ -262,9 +250,9 @@ class MainWindow(QMainWindow, Ui_MainWindow):
QMessageBox.warning(
self,
self.tr("File not found"),
self.tr(
"Cannot find any Schedule object in file '{}'."
).format(os.path.basename(abs_path_filename)),
self.tr("Cannot find any Schedule object in file '{}'.").format(
os.path.basename(abs_path_filename)
),
)
log.info(
"Cannot find any Schedule object in file '{}'.".format(
......@@ -302,8 +290,9 @@ class MainWindow(QMainWindow, Ui_MainWindow):
@Slot()
def close_schedule(self) -> None:
"""
Close current schedule.
SLOT() for SIGNAL(menu_close_schedule.triggered)
Closes current schedule.
"""
if self._graph:
self._graph._signals.component_selected.disconnect(
......@@ -324,8 +313,9 @@ class MainWindow(QMainWindow, Ui_MainWindow):
@Slot()
def save(self) -> None:
"""
Save current schedule.
SLOT() for SIGNAL(menu_save.triggered)
This method save a schedule.
"""
# TODO: all
self._print_button_pressed("save_schedule()")
......@@ -334,19 +324,22 @@ class MainWindow(QMainWindow, Ui_MainWindow):
@Slot()
def save_as(self) -> None:
"""
Save current schedule asking for file name.
SLOT() for SIGNAL(menu_save_as.triggered)
This method save as a schedule.
"""
# TODO: all
# TODO: Implement
self._print_button_pressed("save_schedule()")
self.update_statusbar(self.tr("Schedule saved successfully"))
@Slot(bool)
def show_info_table(self, checked: bool) -> None:
"""
Show or hide the info table.
SLOT(bool) for SIGNAL(menu_node_info.triggered)
Takes in a boolean and hide or show the info table accordingly with
'checked'.
*checked*.
"""
# Note: splitter handler index 0 is a hidden splitter handle far most left,
# use index 1
......@@ -406,9 +399,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
Updates the 'Schedule' part of the info table.
"""
if self.schedule is not None:
self.info_table.item(1, 1).setText(
str(self.schedule.schedule_time)
)
self.info_table.item(1, 1).setText(str(self.schedule.schedule_time))
@Slot(QRectF)
def shrink_scene_to_min_size(self, rect: QRectF) -> None:
......@@ -454,9 +445,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
if ret == QMessageBox.StandardButton.Yes:
if not hide_dialog:
settings.setValue(
"scheduler/hide_exit_dialog", checkbox.isChecked()
)
settings.setValue("scheduler/hide_exit_dialog", checkbox.isChecked())
self._write_settings()
log.info("Exit: {}".format(os.path.basename(__file__)))
event.accept()
......@@ -489,9 +478,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
self._graph._signals.component_selected.connect(
self.info_table_update_component
)
self._graph._signals.component_moved.connect(
self.info_table_update_component
)
self._graph._signals.component_moved.connect(self.info_table_update_component)
self._graph._signals.schedule_time_changed.connect(
self.info_table_update_schedule
)
......@@ -520,12 +507,8 @@ class MainWindow(QMainWindow, Ui_MainWindow):
settings.setValue(
"scheduler/state", self.saveState()
) # toolbars, dockwidgets: pos, size
settings.setValue(
"scheduler/menu/node_info", self.menu_node_info.isChecked()
)
settings.setValue(
"scheduler/splitter/state", self.splitter.saveState()
)
settings.setValue("scheduler/menu/node_info", self.menu_node_info.isChecked())
settings.setValue("scheduler/splitter/state", self.splitter.saveState())
settings.setValue("scheduler/splitter/pos", self.splitter.sizes()[1])
if settings.isWritable():
......@@ -536,9 +519,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
def _read_settings(self) -> None:
"""Read settings from Settings to MainWindow."""
settings = QSettings()
if settings.value(
"scheduler/maximized", defaultValue=False, type=bool
):
if settings.value("scheduler/maximized", defaultValue=False, type=bool):
self.showMaximized()
else:
self.move(settings.value("scheduler/pos", self.pos()))
......@@ -566,9 +547,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
self.info_table.insertRow(1)
self.info_table.setItem(1, 0, QTableWidgetItem("Schedule Time"))
self.info_table.setItem(2, 0, QTableWidgetItem("Cyclic"))
self.info_table.setItem(
1, 1, QTableWidgetItem(str(schedule.schedule_time))
)
self.info_table.setItem(1, 1, QTableWidgetItem(str(schedule.schedule_time)))
self.info_table.setItem(2, 1, QTableWidgetItem(str(schedule.cyclic)))
def _info_table_fill_component(self, graph_id: GraphID) -> None:
......@@ -630,9 +609,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
for _ in range(3):
self.info_table.removeRow(1)
else:
log.error(
"'Operator' not found in info table. It may have been renamed."
)
log.error("'Operator' not found in info table. It may have been renamed.")
def exit_app(self) -> None:
"""Exit application."""
......@@ -647,9 +624,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
for _ in range(self.info_table.rowCount() - row + 1):
self.info_table.removeRow(row + 1)
else:
log.error(
"'Operator' not found in info table. It may have been renamed."
)
log.error("'Operator' not found in info table. It may have been renamed.")
def start_gui() -> None:
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
B-ASIC Scheduler-gui Graphics Component Item Module.
B-ASIC Scheduler-GUI Operation Item Module.
Contains the scheduler-gui OperationItem class for drawing and maintain a component
in a graph.
Contains the scheduler_gui OperationItem class for drawing and maintain an operation
in the schedule.
"""
from typing import TYPE_CHECKING, Dict, List, Union, cast
......@@ -34,23 +34,24 @@ if TYPE_CHECKING:
class OperationItem(QGraphicsItemGroup):
"""
f"""
Class to represent an operation in a graph.
Parameters
----------
operation : :class:`~b_asic.operation.Operation`
The operation.
parent : :class:`~b_asic.scheduler_gui.scheduler_item.SchedulerItem`
height : float, default: 1.0
Parent passed to QGraphicsItemGroup
height : float, default: {OPERATION_HEIGHT}
The height of the operation.
"""
_scale: float = 1.0
"""Static, changed from MainWindow."""
_operation: Operation
_height: float
_ports: Dict[
str, Dict[str, Union[float, QPointF]]
] # ['port-id']['latency/pos']
_ports: Dict[str, Dict[str, Union[float, QPointF]]] # ['port-id']['latency/pos']
_end_time: int
_latency_item: QGraphicsPathItem
_execution_time_item: QGraphicsPathItem
......@@ -72,9 +73,7 @@ class OperationItem(QGraphicsItemGroup):
self._height = height
operation._check_all_latencies_set()
latency_offsets = cast(Dict[str, int], operation.latency_offsets)
self._ports = {
k: {"latency": float(v)} for k, v in latency_offsets.items()
}
self._ports = {k: {"latency": float(v)} for k, v in latency_offsets.items()}
self._end_time = max(latency_offsets.values())
self._port_items = []
......@@ -122,6 +121,14 @@ class OperationItem(QGraphicsItemGroup):
@height.setter
def height(self, height: float) -> None:
"""
Set height.
Parameters
----------
height : float
The new height.
"""
if self._height != height:
self.clear()
self._height = height
......@@ -168,9 +175,7 @@ class OperationItem(QGraphicsItemGroup):
def _make_component(self) -> None:
"""Makes a new component out of the stored attributes."""
latency_outline_pen = QPen(
Qt.GlobalColor.black
) # used by component outline
latency_outline_pen = QPen(Qt.GlobalColor.black) # used by component outline
latency_outline_pen.setWidthF(2 / self._scale)
# latency_outline_pen.setCapStyle(Qt.RoundCap)
# Qt.FlatCap, Qt.SquareCap (default), Qt.RoundCap
......@@ -178,9 +183,7 @@ class OperationItem(QGraphicsItemGroup):
Qt.RoundJoin
) # Qt.MiterJoin, Qt.BevelJoin (default), Qt.RoundJoin, Qt.SvgMiterJoin
port_filling_brush = QBrush(
Qt.GlobalColor.black
) # used by port filling
port_filling_brush = QBrush(Qt.GlobalColor.black) # used by port filling
port_outline_pen = QPen(Qt.GlobalColor.black) # used by port outline
port_outline_pen.setWidthF(0)
# port_outline_pen.setCosmetic(True)
......@@ -214,11 +217,7 @@ class OperationItem(QGraphicsItemGroup):
self._execution_time_item.setPen(execution_time_pen)
# component item
self._set_background(
OPERATION_LATENCY_INACTIVE
) # used by component filling
inputs, outputs = self._operation.get_io_coordinates()
self._set_background(OPERATION_LATENCY_INACTIVE) # used by component filling
def create_ports(io_coordinates, prefix):
for i, (x, y) in enumerate(io_coordinates):
......@@ -234,8 +233,8 @@ class OperationItem(QGraphicsItemGroup):
new_port.setPos(port_pos.x(), port_pos.y())
self._port_items.append(new_port)
create_ports(inputs, "in")
create_ports(outputs, "out")
create_ports(self._operation.get_input_coordinates(), "in")
create_ports(self._operation.get_output_coordinates(), "out")
# op-id/label
self._label_item = QGraphicsSimpleTextItem(self._operation.graph_id)
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
B-ASIC Scheduler-gui Graphics Graph Event Module.
B-ASIC Scheduler-GUI Graphics Scheduler Event Module.
Contains the scheduler-gui SchedulerEvent class containing event filters and
Contains the scheduler_ui SchedulerEvent class containing event filters and
handlers for SchedulerItem objects.
"""
import math
from typing import List, Optional, overload
# QGraphics and QPainter imports
......@@ -61,7 +61,7 @@ class SchedulerEvent: # PyQt5
def is_valid_delta_time(self, delta_time: int) -> bool:
raise NotImplementedError
def set_schedule_time(self, delta_time: int) -> None:
def change_schedule_time(self, delta_time: int) -> None:
raise NotImplementedError
def set_item_active(self, item: OperationItem) -> None:
......@@ -78,9 +78,7 @@ class SchedulerEvent: # PyQt5
...
@overload
def installSceneEventFilters(
self, filterItems: List[QGraphicsItem]
) -> None:
def installSceneEventFilters(self, filterItems: List[QGraphicsItem]) -> None:
...
def installSceneEventFilters(self, filterItems) -> None:
......@@ -98,9 +96,7 @@ class SchedulerEvent: # PyQt5
...
@overload
def removeSceneEventFilters(
self, filterItems: List[QGraphicsItem]
) -> None:
def removeSceneEventFilters(self, filterItems: List[QGraphicsItem]) -> None:
...
def removeSceneEventFilters(self, filterItems) -> None:
......@@ -166,47 +162,31 @@ class SchedulerEvent: # PyQt5
def operation_focusInEvent(self, event: QFocusEvent) -> None:
...
def operation_contextMenuEvent(
self, event: QGraphicsSceneContextMenuEvent
) -> None:
def operation_contextMenuEvent(self, event: QGraphicsSceneContextMenuEvent) -> None:
...
def operation_dragEnterEvent(
self, event: QGraphicsSceneDragDropEvent
) -> None:
def operation_dragEnterEvent(self, event: QGraphicsSceneDragDropEvent) -> None:
...
def operation_dragMoveEvent(
self, event: QGraphicsSceneDragDropEvent
) -> None:
def operation_dragMoveEvent(self, event: QGraphicsSceneDragDropEvent) -> None:
...
def operation_dragLeaveEvent(
self, event: QGraphicsSceneDragDropEvent
) -> None:
def operation_dragLeaveEvent(self, event: QGraphicsSceneDragDropEvent) -> None:
...
def operation_dropEvent(self, event: QGraphicsSceneDragDropEvent) -> None:
...
def operation_hoverEnterEvent(
self, event: QGraphicsSceneHoverEvent
) -> None:
def operation_hoverEnterEvent(self, event: QGraphicsSceneHoverEvent) -> None:
...
def operation_hoverMoveEvent(
self, event: QGraphicsSceneHoverEvent
) -> None:
def operation_hoverMoveEvent(self, event: QGraphicsSceneHoverEvent) -> None:
...
def operation_hoverLeaveEvent(
self, event: QGraphicsSceneHoverEvent
) -> None:
def operation_hoverLeaveEvent(self, event: QGraphicsSceneHoverEvent) -> None:
...
def operation_mouseMoveEvent(
self, event: QGraphicsSceneMouseEvent
) -> None:
def operation_mouseMoveEvent(self, event: QGraphicsSceneMouseEvent) -> None:
"""
Set the position of the graphical element in the graphic scene,
translate coordinates of the cursor within the graphic element in the
......@@ -215,24 +195,20 @@ class SchedulerEvent: # PyQt5
def update_pos(operation_item, dx, dy):
pos_x = operation_item.x() + dx
pos_y = operation_item.y() + dy * (
OPERATION_GAP + OPERATION_HEIGHT
)
if self.is_component_valid_pos(operation_item, pos_x):
pos_y = operation_item.y() + dy * (OPERATION_GAP + OPERATION_HEIGHT)
operation_item.setX(pos_x)
operation_item.setY(pos_y)
self._current_pos.setX(self._current_pos.x() + dx)
self._current_pos.setY(self._current_pos.y() + dy)
self._redraw_lines(operation_item)
self._schedule._y_locations[
operation_item.operation.graph_id
] += dy
self._schedule._y_locations[operation_item.operation.graph_id] += dy
item: OperationItem = self.scene().mouseGrabberItem()
delta_x = (item.mapToParent(event.pos()) - self._current_pos).x()
delta_y = (item.mapToParent(event.pos()) - self._current_pos).y()
delta_y_steps = round(delta_y / (OPERATION_GAP + OPERATION_HEIGHT))
delta_y_steps = round(2 * delta_y / (OPERATION_GAP + OPERATION_HEIGHT)) / 2
if delta_x > 0.505:
update_pos(item, 1, delta_y_steps)
elif delta_x < -0.505:
......@@ -240,9 +216,7 @@ class SchedulerEvent: # PyQt5
elif delta_y_steps != 0:
update_pos(item, 0, delta_y_steps)
def operation_mousePressEvent(
self, event: QGraphicsSceneMouseEvent
) -> None:
def operation_mousePressEvent(self, event: QGraphicsSceneMouseEvent) -> None:
"""
Changes the cursor to ClosedHandCursor when grabbing an object and
stores the current position in item's parent coordinates. *event* will
......@@ -250,14 +224,13 @@ class SchedulerEvent: # PyQt5
allows the item to receive future move, release and double-click events.
"""
item: OperationItem = self.scene().mouseGrabberItem()
self._old_op_position = self._schedule._y_locations[item.operation.graph_id]
self._signals.component_selected.emit(item.graph_id)
self._current_pos = item.mapToParent(event.pos())
self.set_item_active(item)
event.accept()
def operation_mouseReleaseEvent(
self, event: QGraphicsSceneMouseEvent
) -> None:
def operation_mouseReleaseEvent(self, event: QGraphicsSceneMouseEvent) -> None:
"""Change the cursor to OpenHandCursor when releasing an object."""
item: OperationItem = self.scene().mouseGrabberItem()
self.set_item_inactive(item)
......@@ -270,14 +243,20 @@ class SchedulerEvent: # PyQt5
if pos_x > self._schedule.schedule_time:
pos_x = pos_x % self._schedule.schedule_time
redraw = True
if self._schedule._y_locations[item.operation.graph_id] % 1:
# TODO: move other operations
self._schedule._y_locations[item.operation.graph_id] = math.ceil(
self._schedule._y_locations[item.operation.graph_id]
)
pos_y = item.y() + (OPERATION_GAP + OPERATION_HEIGHT) / 2
item.setY(pos_y)
redraw = True
if redraw:
item.setX(pos_x)
self._redraw_lines(item)
self._signals.component_moved.emit(item.graph_id)
def operation_mouseDoubleClickEvent(
self, event: QGraphicsSceneMouseEvent
) -> None:
def operation_mouseDoubleClickEvent(self, event: QGraphicsSceneMouseEvent) -> None:
...
def operation_wheelEvent(self, event: QGraphicsSceneWheelEvent) -> None:
......@@ -309,9 +288,7 @@ class SchedulerEvent: # PyQt5
elif delta_x < -0.505:
update_pos(item, -1)
def timeline_mousePressEvent(
self, event: QGraphicsSceneMouseEvent
) -> None:
def timeline_mousePressEvent(self, event: QGraphicsSceneMouseEvent) -> None:
"""
Store the current position in item's parent coordinates. *event* will
by default be accepted, and this item is then the mouse grabber. This
......@@ -324,12 +301,10 @@ class SchedulerEvent: # PyQt5
self._current_pos = item.mapToParent(event.pos())
event.accept()
def timeline_mouseReleaseEvent(
self, event: QGraphicsSceneMouseEvent
) -> None:
def timeline_mouseReleaseEvent(self, event: QGraphicsSceneMouseEvent) -> None:
"""Updates the schedule time."""
item: TimelineItem = self.scene().mouseGrabberItem()
item.hide_label()
if self._delta_time != 0:
self.set_schedule_time(self._delta_time)
self.change_schedule_time(self._delta_time)
self._signals.schedule_time_changed.emit()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
B-ASIC Scheduler-gui Graphics Graph Item Module.
B-ASIC Scheduler-GUI Scheduler Item Module.
Contains the scheduler-gui SchedulerItem class for drawing and
maintain a component in a graph.
Contains the scheduler_gui SchedulerItem class for drawing and
maintaining a schedule.
"""
from collections import defaultdict
from math import floor
......@@ -31,7 +31,9 @@ from b_asic.scheduler_gui.signal_item import SignalItem
class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): # PySide2 / PyQt5
"""
A class to represent a graph in a QGraphicsScene. This class is a
A class to represent a schedule in a QGraphicsScene.
This class is a
subclass of QGraphicsItemGroup and contains the objects, axes from
AxesItem, as well as components from OperationItem. It
also inherits from SchedulerEvent, which acts as a filter for events
......@@ -53,9 +55,7 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): # PySide2 / PyQt5
_event_items: List[QGraphicsItem]
_signal_dict: Dict[OperationItem, Set[SignalItem]]
def __init__(
self, schedule: Schedule, parent: Optional[QGraphicsItem] = None
):
def __init__(self, schedule: Schedule, parent: Optional[QGraphicsItem] = None):
"""
Construct a SchedulerItem. *parent* is passed to QGraphicsItemGroup's
constructor.
......@@ -176,27 +176,23 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): # PySide2 / PyQt5
def is_valid_delta_time(self, delta_time: int) -> bool:
"""
Takes in a delta time and returns True if the schedule time can be changed by
*delta_time*. False otherwise.
Return True if the schedule time can be changed by *delta_time*.
"""
# TODO: implement
# item = self.scene().mouseGrabberItem()
if self.schedule is None:
raise ValueError("No schedule installed.")
return (
self.schedule.schedule_time + delta_time
>= self.schedule.get_max_end_time()
self.schedule.schedule_time + delta_time >= self.schedule.get_max_end_time()
)
def set_schedule_time(self, delta_time: int) -> None:
def change_schedule_time(self, delta_time: int) -> None:
"""Change the schedule time by *delta_time* and redraw the graph."""
if self._axes is None:
raise RuntimeError("No AxesItem!")
if self.schedule is None:
raise ValueError("No schedule installed.")
self.schedule.set_schedule_time(
self.schedule.schedule_time + delta_time
)
self.schedule.set_schedule_time(self.schedule.schedule_time + delta_time)
self._axes.set_width(self._axes.width + delta_time)
# Redraw all lines
self._redraw_all_lines()
......@@ -223,9 +219,7 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): # PySide2 / PyQt5
op_item = self._operation_items[graph_id]
op_item.setPos(
self._x_axis_indent + self.schedule.start_times[graph_id],
self.schedule._get_y_position(
graph_id, OPERATION_HEIGHT, OPERATION_GAP
),
self.schedule._get_y_position(graph_id, OPERATION_HEIGHT, OPERATION_GAP),
)
def _redraw_from_start(self) -> None:
......@@ -261,9 +255,7 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): # PySide2 / PyQt5
# build components
for graph_id in self.schedule.start_times.keys():
operation = cast(Operation, self.schedule.sfg.find_by_id(graph_id))
component = OperationItem(
operation, height=OPERATION_HEIGHT, parent=self
)
component = OperationItem(operation, height=OPERATION_HEIGHT, parent=self)
self._operation_items[graph_id] = component
self._set_position(graph_id)
self._event_items += component.event_items
......
"""
B-ASIC Scheduler-GUI Signal Item Module.
Contains the scheduler_gui SignalItem class for drawing and maintaining a signal
in the schedule.
"""
from typing import TYPE_CHECKING, Optional, cast
from qtpy.QtCore import QPointF
......@@ -24,14 +32,14 @@ class SignalItem(QGraphicsPathItem):
Parameters
----------
src_operation : OperationItem
src_operation : `~b_asic.scheduler_gui.operation_item.OperationItem`
The operation that the signal is drawn from.
dest_operation : OperationItem
dest_operation : `~b_asic.scheduler_gui.operation_item.OperationItem`
The operation that the signal is drawn to.
signal : Signal
signal : `~b_asic.signal.Signal`
The signal on the SFG level.
parent : QGraphicsItem, optional
The parent QGraphicsItem.
The parent QGraphicsItem passed to QGraphicsPathItem.
"""
_path: Optional[QPainterPath] = None
......@@ -75,7 +83,6 @@ class SignalItem(QGraphicsPathItem):
schedule = cast("SchedulerItem", self.parentItem()).schedule
if dest_x - source_x <= -0.1 or schedule._laps[self._signal.graph_id]:
offset = SCHEDULE_INDENT # TODO: Get from parent/axes...
laps = schedule._laps[self._signal.graph_id]
path.lineTo(schedule.schedule_time + offset, source_y)
path.moveTo(0 + offset, dest_y)
path.lineTo(dest_x, dest_y)
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
B-ASIC Scheduler-gui Graphics Timeline Item Module.
B-ASIC Scheduler-GUI Timeline Item Module.
Contains the scheduler-gui TimelineItem class for drawing and
maintain the timeline in a graph.
Contains the scheduler_gui TimelineItem class for drawing and
maintain the timeline in a schedule.
"""
from typing import List, Optional, overload
......@@ -21,9 +21,7 @@ class TimelineItem(QGraphicsLineItem):
_delta_time_label: QGraphicsTextItem
@overload
def __init__(
self, line: QLineF, parent: Optional[QGraphicsItem] = None
) -> None:
def __init__(self, line: QLineF, parent: Optional[QGraphicsItem] = None) -> None:
"""
Constructs a TimelineItem out of 'line'. 'parent' is passed to
QGraphicsLineItem's constructor.
......
......@@ -29,8 +29,8 @@ def wdf_allpass(
execution_time: Optional[int] = None,
) -> SFG:
"""
Generate a signal flow graph of a WDF allpass section based on symmetric two-port
adaptors.
Generate a signal flow graph of a WDF allpass section based on symmetric two-port\
adaptors.
Parameters
----------
......@@ -60,6 +60,9 @@ def wdf_allpass(
Signal flow graph
"""
np_coefficients = np.squeeze(np.asarray(coefficients))
order = len(np_coefficients)
if not order:
raise ValueError("Coefficients cannot be empty")
if np_coefficients.ndim != 1:
raise TypeError("coefficients must be a 1D-array")
if input_op is None:
......@@ -68,7 +71,6 @@ def wdf_allpass(
output = Output()
if name is None:
name = "WDF allpass section"
order = len(np_coefficients)
odd_order = order % 2
if odd_order:
# First-order section
......@@ -120,16 +122,13 @@ def direct_form_fir(
input_op: Optional[Union[Input, Signal, InputPort]] = None,
output: Optional[Union[Output, Signal, OutputPort]] = None,
name: Optional[str] = None,
mult_properties: Optional[
Union[Dict[str, int], Dict[str, Dict[str, int]]]
] = None,
add_properties: Optional[
Union[Dict[str, int], Dict[str, Dict[str, int]]]
] = None,
):
mult_properties: Optional[Union[Dict[str, int], Dict[str, Dict[str, int]]]] = None,
add_properties: Optional[Union[Dict[str, int], Dict[str, Dict[str, int]]]] = None,
) -> SFG:
r"""
Generate a signal flow graph of a direct form FIR filter. The *coefficients* parameter is a
sequence of impulse response values::
Generate a signal flow graph of a direct form FIR filter.
The *coefficients* parameter is a sequence of impulse response values::
coefficients = [h0, h1, h2, ..., hN]
......@@ -149,7 +148,7 @@ def direct_form_fir(
The Output to connect the SFG to. If not provided, one will be generated.
name : Name, optional
The name of the SFG. If None, "WDF allpass section".
The name of the SFG. If None, "Direct-form FIR filter".
mult_properties : dictionary, optional
Properties passed to :class:`~b_asic.core_operations.ConstantMultiplication`.
......@@ -166,6 +165,9 @@ def direct_form_fir(
transposed_direct_form_fir
"""
np_coefficients = np.squeeze(np.asarray(coefficients))
taps = len(np_coefficients)
if not taps:
raise ValueError("Coefficients cannot be empty")
if np_coefficients.ndim != 1:
raise TypeError("coefficients must be a 1D-array")
if input_op is None:
......@@ -179,7 +181,6 @@ def direct_form_fir(
if add_properties is None:
add_properties = {}
taps = len(np_coefficients)
prev_delay = input_op
prev_add = None
for i, coeff in enumerate(np_coefficients):
......@@ -202,16 +203,13 @@ def transposed_direct_form_fir(
input_op: Optional[Union[Input, Signal, InputPort]] = None,
output: Optional[Union[Output, Signal, OutputPort]] = None,
name: Optional[str] = None,
mult_properties: Optional[
Union[Dict[str, int], Dict[str, Dict[str, int]]]
] = None,
add_properties: Optional[
Union[Dict[str, int], Dict[str, Dict[str, int]]]
] = None,
):
mult_properties: Optional[Union[Dict[str, int], Dict[str, Dict[str, int]]]] = None,
add_properties: Optional[Union[Dict[str, int], Dict[str, Dict[str, int]]]] = None,
) -> SFG:
r"""
Generate a signal flow graph of a transposed direct form FIR filter. The *coefficients* parameter is a
sequence of impulse response values::
Generate a signal flow graph of a transposed direct form FIR filter.
The *coefficients* parameter is a sequence of impulse response values::
coefficients = [h0, h1, h2, ..., hN]
......@@ -231,7 +229,7 @@ def transposed_direct_form_fir(
The Output to connect the SFG to. If not provided, one will be generated.
name : Name, optional
The name of the SFG. If None, "WDF allpass section".
The name of the SFG. If None, "Transposed direct-form FIR filter".
mult_properties : dictionary, optional
Properties passed to :class:`~b_asic.core_operations.ConstantMultiplication`.
......@@ -248,6 +246,9 @@ def transposed_direct_form_fir(
direct_form_fir
"""
np_coefficients = np.squeeze(np.asarray(coefficients))
taps = len(np_coefficients)
if not taps:
raise ValueError("Coefficients cannot be empty")
if np_coefficients.ndim != 1:
raise TypeError("coefficients must be a 1D-array")
if input_op is None:
......@@ -261,9 +262,7 @@ def transposed_direct_form_fir(
if add_properties is None:
add_properties = {}
taps = len(np_coefficients)
prev_delay = None
prev_add = None
for i, coeff in enumerate(reversed(np_coefficients)):
tmp_mul = ConstantMultiplication(coeff, input_op, **mult_properties)
tmp_add = (
......