Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • da/B-ASIC
  • lukja239/B-ASIC
  • robal695/B-ASIC
3 results
Show changes
Commits on Source (25)
Showing
with 913 additions and 496 deletions
......@@ -9,6 +9,7 @@ import logging
import os
import sys
from pprint import pprint
from typing import Optional, Tuple
from qtpy.QtCore import QFileInfo, QSize, Qt
from qtpy.QtGui import QCursor, QIcon, QKeySequence, QPainter
......@@ -33,7 +34,6 @@ from b_asic.GUI._preferences import GAP, GRID, MINBUTTONSIZE, PORTHEIGHT
from b_asic.GUI.arrow import Arrow
from b_asic.GUI.drag_button import DragButton
from b_asic.GUI.gui_interface import Ui_main_window
from b_asic.GUI.plot_window import PlotWindow
from b_asic.GUI.port_button import PortButton
from b_asic.GUI.select_sfg_window import SelectSFGWindow
from b_asic.GUI.show_pc_window import ShowPCWindow
......@@ -43,6 +43,8 @@ from b_asic.GUI.simulate_sfg_window import SimulateSFGWindow
from b_asic.GUI.util_dialogs import FaqWindow, KeybindsWindow
from b_asic.GUI.utils import decorate_class, handle_error
from b_asic.gui_utils.about_window import AboutWindow
from b_asic.gui_utils.plot_window import PlotWindow
from b_asic.operation import Operation
from b_asic.port import InputPort, OutputPort
from b_asic.save_load_structure import python_to_sfg, sfg_to_python
from b_asic.signal_flow_graph import SFG
......@@ -88,9 +90,7 @@ class MainWindow(QMainWindow):
b_asic.special_operations, self.ui.special_operations_list
)
self.shortcut_core = QShortcut(
QKeySequence("Ctrl+R"), self.ui.operation_box
)
self.shortcut_core = QShortcut(QKeySequence("Ctrl+R"), self.ui.operation_box)
self.shortcut_core.activated.connect(
self._refresh_operations_list_from_namespace
)
......@@ -140,11 +140,11 @@ class MainWindow(QMainWindow):
self.cursor = QCursor()
def init_ui(self):
def init_ui(self) -> None:
self.create_toolbar_view()
self.create_graphics_view()
def create_graphics_view(self):
def create_graphics_view(self) -> None:
self.graphic_view = QGraphicsView(self.scene, self)
self.graphic_view.setRenderHint(QPainter.Antialiasing)
self.graphic_view.setGeometry(
......@@ -152,12 +152,12 @@ class MainWindow(QMainWindow):
)
self.graphic_view.setDragMode(QGraphicsView.RubberBandDrag)
def create_toolbar_view(self):
def create_toolbar_view(self) -> None:
self.toolbar = self.addToolBar("Toolbar")
self.toolbar.addAction("Create SFG", self.create_sfg_from_toolbar)
self.toolbar.addAction("Clear workspace", self.clear_workspace)
def resizeEvent(self, event):
def resizeEvent(self, event) -> None:
self.ui.operation_box.setGeometry(
10, 10, self.ui.operation_box.width(), self.height()
)
......@@ -169,14 +169,14 @@ class MainWindow(QMainWindow):
)
super().resizeEvent(event)
def wheelEvent(self, event):
def wheelEvent(self, event) -> None:
if event.modifiers() == Qt.KeyboardModifier.ControlModifier:
old_zoom = self.zoom
self.zoom += event.angleDelta().y() / 2500
self.graphic_view.scale(self.zoom, self.zoom)
self.zoom = old_zoom
def view_operation_names(self):
def view_operation_names(self) -> None:
if self.check_show_names.isChecked():
self.is_show_names = True
else:
......@@ -186,7 +186,7 @@ class MainWindow(QMainWindow):
operation.label.setOpacity(self.is_show_names)
operation.is_show_name = self.is_show_names
def _save_work(self):
def _save_work(self) -> None:
sfg = self.sfg_widget.sfg
file_dialog = QFileDialog()
file_dialog.setDefaultSuffix(".py")
......@@ -206,14 +206,10 @@ class MainWindow(QMainWindow):
try:
with open(module, "w+") as file_obj:
file_obj.write(
sfg_to_python(
sfg, suffix=f"positions = {operation_positions}"
)
sfg_to_python(sfg, suffix=f"positions = {operation_positions}")
)
except Exception as e:
self.logger.error(
f"Failed to save SFG to path: {module}, with error: {e}."
)
self.logger.error(f"Failed to save SFG to path: {module}, with error: {e}.")
return
self.logger.info("Saved SFG to path: " + str(module))
......@@ -237,8 +233,7 @@ class MainWindow(QMainWindow):
sfg, positions = python_to_sfg(module)
except ImportError as e:
self.logger.error(
f"Failed to load module: {module} with the following error:"
f" {e}."
f"Failed to load module: {module} with the following error: {e}."
)
return
......@@ -266,12 +261,8 @@ class MainWindow(QMainWindow):
# print(op)
self.create_operation(
op,
positions[op.graph_id][0:2]
if op.graph_id in positions
else None,
positions[op.graph_id][-1]
if op.graph_id in positions
else None,
positions[op.graph_id][0:2] if op.graph_id in positions else None,
positions[op.graph_id][-1] if op.graph_id in positions else None,
)
def connect_ports(ports):
......@@ -287,9 +278,7 @@ class MainWindow(QMainWindow):
destination = [
destination
for destination in self.portDict[
self.operationDragDict[
signal.destination.operation
]
self.operationDragDict[signal.destination.operation]
]
if destination.port is signal.destination
]
......@@ -328,7 +317,7 @@ class MainWindow(QMainWindow):
self.scene.clear()
self.logger.info("Workspace cleared.")
def create_sfg_from_toolbar(self):
def create_sfg_from_toolbar(self) -> None:
inputs = []
outputs = []
for op in self.pressed_operations:
......@@ -347,14 +336,10 @@ class MainWindow(QMainWindow):
self.logger.warning("Failed to initialize SFG with empty name.")
return
self.logger.info(
"Creating SFG with name: %s from selected operations." % name
)
self.logger.info("Creating SFG with name: %s from selected operations." % name)
sfg = SFG(inputs=inputs, outputs=outputs, name=name)
self.logger.info(
"Created SFG with name: %s from selected operations." % name
)
self.logger.info("Created SFG with name: %s from selected operations." % name)
def check_equality(signal, signal_2):
if not (
......@@ -372,8 +357,7 @@ class MainWindow(QMainWindow):
and hasattr(signal_2.destination.operation, "value")
):
if not (
signal.source.operation.value
== signal_2.source.operation.value
signal.source.operation.value == signal_2.source.operation.value
and signal.destination.operation.value
== signal_2.destination.operation.value
):
......@@ -386,8 +370,7 @@ class MainWindow(QMainWindow):
and hasattr(signal_2.destination.operation, "name")
):
if not (
signal.source.operation.name
== signal_2.source.operation.name
signal.source.operation.name == signal_2.source.operation.name
and signal.destination.operation.name
== signal_2.destination.operation.name
):
......@@ -456,12 +439,12 @@ class MainWindow(QMainWindow):
self.sfg_dict[sfg.name] = sfg
def _show_precedence_graph(self, event=None):
def _show_precedence_graph(self, event=None) -> None:
self.dialog = ShowPCWindow(self)
self.dialog.add_sfg_to_dialog()
self.dialog.show()
def get_operations_from_namespace(self, namespace):
def get_operations_from_namespace(self, namespace) -> None:
self.logger.info(
"Fetching operations from namespace: " + str(namespace.__name__)
)
......@@ -471,7 +454,7 @@ class MainWindow(QMainWindow):
if hasattr(getattr(namespace, comp), "type_name")
]
def add_operations_from_namespace(self, namespace, _list):
def add_operations_from_namespace(self, namespace, _list) -> None:
for attr_name in self.get_operations_from_namespace(namespace):
attr = getattr(namespace, attr_name)
try:
......@@ -482,11 +465,9 @@ class MainWindow(QMainWindow):
except NotImplementedError:
pass
self.logger.info(
"Added operations from namespace: " + str(namespace.__name__)
)
self.logger.info("Added operations from namespace: " + str(namespace.__name__))
def add_namespace(self, event=None):
def add_namespace(self, event=None) -> None:
module, accepted = QFileDialog().getOpenFileName()
if not accepted:
return
......@@ -497,16 +478,25 @@ class MainWindow(QMainWindow):
namespace = importlib.util.module_from_spec(spec)
spec.loader.exec_module(namespace)
self.add_operations_from_namespace(
namespace, self.ui.custom_operations_list
)
self.add_operations_from_namespace(namespace, self.ui.custom_operations_list)
def create_operation(self, op, position=None, is_flipped: bool = False):
def create_operation(
self,
op: Operation,
position: Optional[Tuple[float, float]] = None,
is_flipped: bool = False,
) -> None:
"""
Parameters
----------
op : Operation
position : (float, float), optional
is_flipped : bool, default: False
"""
try:
if op in self.operationDragDict:
self.logger.warning(
"Multiple instances of operation with same name"
)
self.logger.warning("Multiple instances of operation with same name")
return
attr_button = DragButton(op.graph_id, op, True, window=self)
......@@ -573,7 +563,7 @@ class MainWindow(QMainWindow):
"Unexpected error occurred while creating operation: " + str(e)
)
def _create_operation_item(self, item):
def _create_operation_item(self, item) -> None:
self.logger.info("Creating operation of type: %s" % str(item.text()))
try:
attr_oper = self._operations_from_name[item.text()]()
......@@ -583,7 +573,7 @@ class MainWindow(QMainWindow):
"Unexpected error occurred while creating operation: " + str(e)
)
def _refresh_operations_list_from_namespace(self):
def _refresh_operations_list_from_namespace(self) -> None:
self.logger.info("Refreshing operation list.")
self.ui.core_operations_list.clear()
self.ui.special_operations_list.clear()
......@@ -596,10 +586,10 @@ class MainWindow(QMainWindow):
)
self.logger.info("Finished refreshing operation list.")
def on_list_widget_item_clicked(self, item):
def on_list_widget_item_clicked(self, item) -> None:
self._create_operation_item(item)
def keyPressEvent(self, event):
def keyPressEvent(self, event) -> None:
if event.key() == Qt.Key.Key_Delete:
for pressed_op in self.pressed_operations:
pressed_op.remove()
......@@ -607,11 +597,10 @@ class MainWindow(QMainWindow):
self.pressed_operations.clear()
super().keyPressEvent(event)
def _connect_callback(self, *event):
def _connect_callback(self, *event) -> None:
if len(self.pressed_ports) < 2:
self.logger.warning(
"Cannot connect less than two ports. Please select at least"
" two."
"Cannot connect less than two ports. Please select at least two."
)
return
......@@ -636,9 +625,7 @@ class MainWindow(QMainWindow):
for port in self.pressed_ports:
port.select_port()
def _connect_button(
self, source: PortButton, destination: PortButton
) -> None:
def _connect_button(self, source: PortButton, destination: PortButton) -> None:
"""
Connect two PortButtons with an Arrow.
......@@ -698,24 +685,18 @@ class MainWindow(QMainWindow):
def _simulate_sfg(self):
for sfg, properties in self.dialog.properties.items():
self.logger.info("Simulating SFG with name: %s" % str(sfg.name))
simulation = FastSimulation(
sfg, input_providers=properties["input_values"]
)
simulation = FastSimulation(sfg, input_providers=properties["input_values"])
l_result = simulation.run_for(
properties["iteration_count"],
save_results=properties["all_results"],
)
print(f"{'=' * 10} {sfg.name} {'=' * 10}")
pprint(
simulation.results if properties["all_results"] else l_result
)
pprint(simulation.results if properties["all_results"] else l_result)
print(f"{'=' * 10} /{sfg.name} {'=' * 10}")
if properties["show_plot"]:
self.logger.info(
"Opening plot for SFG with name: " + str(sfg.name)
)
self.logger.info("Opening plot for SFG with name: " + str(sfg.name))
self.logger.info(
"To save the plot press 'Ctrl+S' when the plot is focused."
)
......
"""
B-ASIC select SFG window.
"""
from typing import TYPE_CHECKING
from qtpy.QtCore import Qt, Signal
from qtpy.QtWidgets import QComboBox, QDialog, QPushButton, QVBoxLayout
if TYPE_CHECKING:
from b_asic.GUI.main_window import MainWindow
class SelectSFGWindow(QDialog):
ok = Signal()
def __init__(self, window):
def __init__(self, window: "MainWindow"):
super().__init__()
self._window = window
self.setWindowFlags(Qt.WindowTitleHint | Qt.WindowCloseButtonHint)
......@@ -23,9 +27,8 @@ class SelectSFGWindow(QDialog):
self.sfg = None
self.setLayout(self.dialog_layout)
self.add_sfgs_to_layout()
def add_sfgs_to_layout(self):
# Add SFGs to layout
for sfg in self._window.sfg_dict:
self.combo_box.addItem(sfg)
......
# -*- coding: utf-8 -*-
from qtpy.QtWidgets import QGridLayout, QLabel, QLineEdit, QSpinBox
from b_asic.signal_generator import (
Constant,
Gaussian,
Impulse,
SignalGenerator,
Sinusoid,
Step,
Uniform,
ZeroPad,
)
class SignalGeneratorInput(QGridLayout):
"""Abstract class for graphically configuring and generating signal generators."""
def __init__(self, logger, *args, **kwargs):
super().__init__(*args, **kwargs)
self._logger = logger
def get_generator(self) -> SignalGenerator:
"""Return the SignalGenerator based on the graphical input."""
raise NotImplementedError
class DelayInput(SignalGeneratorInput):
"""
Abstract class for graphically configuring and generating signal generators that
have a single delay parameter.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.delay_label = QLabel("Delay")
self.addWidget(self.delay_label, 0, 0)
self.delay_spin_box = QSpinBox()
self.delay_spin_box.setRange(0, 2147483647)
self.addWidget(self.delay_spin_box, 0, 1)
def get_generator(self) -> SignalGenerator:
raise NotImplementedError
class ImpulseInput(DelayInput):
"""
Class for graphically configuring and generating a
:class:`~b_asic.signal_generators.Impulse` signal generator.
"""
def get_generator(self) -> SignalGenerator:
return Impulse(self.delay_spin_box.value())
class StepInput(DelayInput):
"""
Class for graphically configuring and generating a
:class:`~b_asic.signal_generators.Step` signal generator.
"""
def get_generator(self) -> SignalGenerator:
return Step(self.delay_spin_box.value())
class ZeroPadInput(SignalGeneratorInput):
"""
Class for graphically configuring and generating a
:class:`~b_asic.signal_generators.ZeroPad` signal generator.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_label = QLabel("Input")
self.addWidget(self.input_label, 0, 0)
self.input_sequence = QLineEdit()
self.addWidget(self.input_sequence, 0, 1)
def get_generator(self) -> SignalGenerator:
input_values = []
for val in self.input_sequence.text().split(","):
val = val.strip()
try:
if not val:
val = 0
val = complex(val)
except ValueError:
self._logger.warning(f"Skipping value: {val}, not a digit.")
continue
input_values.append(val)
return ZeroPad(input_values)
class SinusoidInput(SignalGeneratorInput):
"""
Class for graphically configuring and generating a
:class:`~b_asic.signal_generators.Sinusoid` signal generator.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.frequency_label = QLabel("Frequency")
self.addWidget(self.frequency_label, 0, 0)
self.frequency_input = QLineEdit()
self.addWidget(self.frequency_input, 0, 1)
self.phase_label = QLabel("Phase")
self.addWidget(self.phase_label, 1, 0)
self.phase_input = QLineEdit()
self.addWidget(self.phase_input, 1, 1)
def get_generator(self) -> SignalGenerator:
frequency = self.frequency_input.text().strip()
try:
if not frequency:
frequency = 0.1
frequency = float(frequency)
except ValueError:
self._logger.warning(f"Cannot parse frequency: {frequency} not a number.")
frequency = 0.1
phase = self.phase_input.text().strip()
try:
if not phase:
phase = 0
phase = float(phase)
except ValueError:
self._logger.warning(f"Cannot parse phase: {phase} not a number.")
phase = 0
return Sinusoid(frequency, phase)
class GaussianInput(SignalGeneratorInput):
"""
Class for graphically configuring and generating a
:class:`~b_asic.signal_generators.Gaussian` signal generator.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.scale_label = QLabel("Standard deviation")
self.addWidget(self.scale_label, 0, 0)
self.scale_input = QLineEdit()
self.scale_input.setText("1.0")
self.addWidget(self.scale_input, 0, 1)
self.loc_label = QLabel("Average value")
self.addWidget(self.loc_label, 1, 0)
self.loc_input = QLineEdit()
self.loc_input.setText("0.0")
self.addWidget(self.loc_input, 1, 1)
self.seed_label = QLabel("Seed")
self.addWidget(self.seed_label, 2, 0)
self.seed_spin_box = QSpinBox()
self.seed_spin_box.setRange(0, 2147483647)
self.addWidget(self.seed_spin_box, 2, 1)
def get_generator(self) -> SignalGenerator:
scale = self.scale_input.text().strip()
try:
if not scale:
scale = 1
scale = float(scale)
except ValueError:
self._logger.warning(f"Cannot parse scale: {scale} not a number.")
scale = 1
loc = self.loc_input.text().strip()
try:
if not loc:
loc = 0
loc = float(loc)
except ValueError:
self._logger.warning(f"Cannot parse loc: {loc} not a number.")
loc = 0
return Gaussian(self.seed_spin_box.value(), loc, scale)
class UniformInput(SignalGeneratorInput):
"""
Class for graphically configuring and generating a
:class:`~b_asic.signal_generators.Uniform` signal generator.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.low_label = QLabel("Lower bound")
self.addWidget(self.low_label, 0, 0)
self.low_input = QLineEdit()
self.low_input.setText("-1.0")
self.addWidget(self.low_input, 0, 1)
self.high_label = QLabel("Upper bound")
self.addWidget(self.high_label, 1, 0)
self.high_input = QLineEdit()
self.high_input.setText("1.0")
self.addWidget(self.high_input, 1, 1)
self.seed_label = QLabel("Seed")
self.addWidget(self.seed_label, 2, 0)
self.seed_spin_box = QSpinBox()
self.seed_spin_box.setRange(0, 2147483647)
self.addWidget(self.seed_spin_box, 2, 1)
def get_generator(self) -> SignalGenerator:
low = self.low_input.text().strip()
try:
if not low:
low = -1.0
low = float(low)
except ValueError:
self._logger.warning(f"Cannot parse low: {low} not a number.")
low = -1.0
high = self.high_input.text().strip()
try:
if not high:
high = 1.0
high = float(high)
except ValueError:
self._logger.warning(f"Cannot parse high: {high} not a number.")
high = 1.0
return Uniform(self.seed_spin_box.value(), low, high)
class ConstantInput(SignalGeneratorInput):
"""
Class for graphically configuring and generating a
:class:`~b_asic.signal_generators.Constant` signal generator.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.constant_label = QLabel("Constant")
self.addWidget(self.constant_label, 0, 0)
self.constant_input = QLineEdit()
self.constant_input.setText("1.0")
self.addWidget(self.constant_input, 0, 1)
def get_generator(self) -> SignalGenerator:
constant = self.constant_input.text().strip()
try:
if not constant:
constant = 1.0
constant = complex(constant)
except ValueError:
self._logger.warning(f"Cannot parse constant: {constant} not a number.")
constant = 0.0
return Constant(constant)
_GENERATOR_MAPPING = {
"Constant": ConstantInput,
"Gaussian": GaussianInput,
"Impulse": ImpulseInput,
"Sinusoid": SinusoidInput,
"Step": StepInput,
"Uniform": UniformInput,
"ZeroPad": ZeroPadInput,
}
......@@ -2,9 +2,7 @@
B-ASIC window to simulate an SFG.
"""
import numpy as np
from matplotlib.backends.backend_qt5agg import (
FigureCanvasQTAgg as FigureCanvas,
)
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from qtpy.QtCore import Qt, Signal
from qtpy.QtGui import QKeySequence
......@@ -26,7 +24,7 @@ from qtpy.QtWidgets import (
QVBoxLayout,
)
from b_asic.signal_generator import Impulse, Step, ZeroPad
from b_asic.GUI.signal_generator_input import _GENERATOR_MAPPING
class SimulateSFGWindow(QDialog):
......@@ -58,12 +56,15 @@ class SimulateSFGWindow(QDialog):
spin_box = QSpinBox()
spin_box.setRange(0, 2147483647)
spin_box.setValue(100)
options_layout.addRow("Iteration count: ", spin_box)
check_box_plot = QCheckBox()
check_box_plot.setCheckState(Qt.CheckState.Checked)
options_layout.addRow("Plot results: ", check_box_plot)
check_box_all = QCheckBox()
check_box_all.setCheckState(Qt.CheckState.Checked)
options_layout.addRow("Get all results: ", check_box_all)
sfg_layout.addLayout(options_layout)
......@@ -89,14 +90,12 @@ class SimulateSFGWindow(QDialog):
input_dropdown = QComboBox()
input_dropdown.insertItems(
0, ["Impulse", "Step", "Input", "File"]
0, list(_GENERATOR_MAPPING.keys()) + ["File"]
)
input_dropdown.currentTextChanged.connect(
lambda text, i=i: self.change_input_format(i, text)
)
self.input_grid.addWidget(
input_dropdown, i, 1, alignment=Qt.AlignLeft
)
self.input_grid.addWidget(input_dropdown, i, 1, alignment=Qt.AlignLeft)
self.change_input_format(i, "Impulse")
......@@ -124,27 +123,8 @@ class SimulateSFGWindow(QDialog):
param_grid = QGridLayout()
if text == "Impulse":
delay_label = QLabel("Delay")
param_grid.addWidget(delay_label, 0, 0)
delay_spin_box = QSpinBox()
delay_spin_box.setRange(0, 2147483647)
param_grid.addWidget(delay_spin_box, 0, 1)
elif text == "Step":
delay_label = QLabel("Delay")
param_grid.addWidget(delay_label, 0, 0)
delay_spin_box = QSpinBox()
delay_spin_box.setRange(0, 2147483647)
param_grid.addWidget(delay_spin_box, 0, 1)
elif text == "Input":
input_label = QLabel("Input")
param_grid.addWidget(input_label, 0, 0)
input_sequence = QLineEdit()
param_grid.addWidget(input_sequence, 0, 1)
zpad_label = QLabel("Zpad")
param_grid.addWidget(zpad_label, 1, 0)
zpad_button = QCheckBox()
param_grid.addWidget(zpad_button, 1, 1)
if text in _GENERATOR_MAPPING:
param_grid = _GENERATOR_MAPPING[text](self._window.logger)
elif text == "File":
file_label = QLabel("Browse")
param_grid.addWidget(file_label, 0, 0)
......@@ -177,9 +157,7 @@ class SimulateSFGWindow(QDialog):
_list_values.append(complex(val))
except ValueError:
self._window.logger.warning(
f"Skipping value: {val}, not a digit."
)
self._window.logger.warning(f"Skipping value: {val}, not a digit.")
continue
_input_values.append(_list_values)
......@@ -192,89 +170,40 @@ class SimulateSFGWindow(QDialog):
if ic_value == 0:
self._window.logger.error("Iteration count is set to zero.")
tmp = []
input_values = []
for i in range(self.input_grid.rowCount()):
in_format = (
self.input_grid.itemAtPosition(i, 1).widget().currentText()
)
in_format = self.input_grid.itemAtPosition(i, 1).widget().currentText()
in_param = self.input_grid.itemAtPosition(i, 2)
tmp2 = []
if in_format == "Impulse":
g = Impulse(in_param.itemAtPosition(0, 1).widget().value())
for j in range(ic_value):
tmp2.append(str(g(j)))
elif in_format == "Step":
g = Step(in_param.itemAtPosition(0, 1).widget().value())
for j in range(ic_value):
tmp2.append(str(g(j)))
elif in_format == "Input":
widget = in_param.itemAtPosition(0, 1).widget()
tmp3 = widget.text().split(",")
if in_param.itemAtPosition(1, 1).widget().isChecked():
g = ZeroPad(tmp3)
for j in range(ic_value):
tmp2.append(str(g(j)))
else:
tmp2 = tmp3
if in_format in _GENERATOR_MAPPING:
tmp2 = in_param.get_generator()
elif in_format == "File":
widget = in_param.itemAtPosition(0, 1).widget()
path = widget.text()
try:
tmp2 = np.loadtxt(path, dtype=str).tolist()
except FileNotFoundError:
self._window.logger.error(
f"Selected input file not found."
tmp2 = self.parse_input_values(
np.loadtxt(path, dtype=str).tolist()
)
except FileNotFoundError:
self._window.logger.error(f"Selected input file not found.")
continue
else:
raise Exception("Input selection is not implemented")
tmp.append(tmp2)
input_values.append(tmp2)
input_values = self.parse_input_values(tmp)
self.properties[sfg] = {
"iteration_count": ic_value,
"show_plot": self.input_fields[sfg]["show_plot"].isChecked(),
"all_results": self.input_fields[sfg]["all_results"].isChecked(),
"input_values": input_values,
}
max_len = max(len(list_) for list_ in input_values)
min_len = min(len(list_) for list_ in input_values)
if max_len != min_len:
self._window.logger.error(
"Minimum length of input lists are not equal to maximum "
f"length of input lists: {max_len} != {min_len}."
)
elif ic_value > min_len:
self._window.logger.error(
"Minimum length of input lists are less than the "
f"iteration count: {ic_value} > {min_len}."
)
else:
self.properties[sfg] = {
"iteration_count": ic_value,
"show_plot": self.input_fields[sfg][
"show_plot"
].isChecked(),
"all_results": self.input_fields[sfg][
"all_results"
].isChecked(),
"input_values": input_values,
}
# If we plot we should also print the entire data,
# since you cannot really interact with the graph.
if self.properties[sfg]["show_plot"]:
self.properties[sfg]["all_results"] = True
continue
self._window.logger.info(
f"Skipping simulation of SFG with name: {sfg.name}, "
"due to previous errors."
)
# If we plot we should also print the entire data,
# since you cannot really interact with the graph.
if self.properties[sfg]["show_plot"]:
self.properties[sfg]["all_results"] = True
self.accept()
self.simulate.emit()
......@@ -296,9 +225,7 @@ class Plot(FigureCanvas):
FigureCanvas.__init__(self, fig)
self.setParent(parent)
FigureCanvas.setSizePolicy(
self, QSizePolicy.Expanding, QSizePolicy.Expanding
)
FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding, QSizePolicy.Expanding)
FigureCanvas.updateGeometry(self)
self.save_figure = QShortcut(QKeySequence("Ctrl+S"), self)
self.save_figure.activated.connect(self._save_plot_figure)
......@@ -307,18 +234,14 @@ class Plot(FigureCanvas):
def _save_plot_figure(self):
self._window.logger.info(f"Saving plot of figure: {self.sfg.name}.")
file_choices = "PNG (*.png)|*.png"
path, ext = QFileDialog.getSaveFileName(
self, "Save file", "", file_choices
)
path, ext = QFileDialog.getSaveFileName(self, "Save file", "", file_choices)
path = path.encode("utf-8")
if not path[-4:] == file_choices[-4:].encode("utf-8"):
path += file_choices[-4:].encode("utf-8")
if path:
self.print_figure(path.decode(), dpi=self.dpi)
self._window.logger.info(
f"Saved plot: {self.sfg.name} to path: {path}."
)
self._window.logger.info(f"Saved plot: {self.sfg.name} to path: {path}.")
def _plot_values_sfg(self):
x_axis = list(range(len(self.simulation.results["0"])))
......
......@@ -62,6 +62,24 @@ class Constant(AbstractOperation):
"""Set the constant value of this operation."""
self.set_param("value", value)
@property
def latency(self) -> int:
return self.latency_offsets["out0"]
def __repr__(self) -> str:
return f"Constant({self.value})"
def __str__(self) -> str:
return f"{self.value}"
@property
def is_linear(self) -> bool:
return True
@property
def is_constant(self) -> bool:
return True
class Addition(AbstractOperation):
"""
......@@ -125,6 +143,10 @@ class Addition(AbstractOperation):
def evaluate(self, a, b):
return a + b
@property
def is_linear(self) -> bool:
return True
class Subtraction(AbstractOperation):
"""
......@@ -185,6 +207,10 @@ class Subtraction(AbstractOperation):
def evaluate(self, a, b):
return a - b
@property
def is_linear(self) -> bool:
return True
class AddSub(AbstractOperation):
r"""
......@@ -266,6 +292,10 @@ class AddSub(AbstractOperation):
"""Set if operation is an addition."""
self.set_param("is_add", is_add)
@property
def is_linear(self) -> bool:
return True
class Multiplication(AbstractOperation):
r"""
......@@ -327,6 +357,12 @@ class Multiplication(AbstractOperation):
def evaluate(self, a, b):
return a * b
@property
def is_linear(self) -> bool:
return any(
input.connected_source.operation.is_constant for input in self.inputs
)
class Division(AbstractOperation):
r"""
......@@ -368,6 +404,10 @@ class Division(AbstractOperation):
def evaluate(self, a, b):
return a / b
@property
def is_linear(self) -> bool:
return self.input(1).connected_source.operation.is_constant
class Min(AbstractOperation):
r"""
......@@ -410,9 +450,7 @@ class Min(AbstractOperation):
def evaluate(self, a, b):
if isinstance(a, complex) or isinstance(b, complex):
raise ValueError(
"core_operations.Min does not support complex numbers."
)
raise ValueError("core_operations.Min does not support complex numbers.")
return a if a < b else b
......@@ -457,9 +495,7 @@ class Max(AbstractOperation):
def evaluate(self, a, b):
if isinstance(a, complex) or isinstance(b, complex):
raise ValueError(
"core_operations.Max does not support complex numbers."
)
raise ValueError("core_operations.Max does not support complex numbers.")
return a if a > b else b
......@@ -589,8 +625,7 @@ class ConstantMultiplication(AbstractOperation):
latency_offsets: Optional[Dict[str, int]] = None,
execution_time: Optional[int] = None,
):
"""Construct a ConstantMultiplication operation with the given value.
"""
"""Construct a ConstantMultiplication operation with the given value."""
super().__init__(
input_count=1,
output_count=1,
......@@ -619,6 +654,10 @@ class ConstantMultiplication(AbstractOperation):
"""Set the constant value of this operation."""
self.set_param("value", value)
@property
def is_linear(self) -> bool:
return True
class Butterfly(AbstractOperation):
r"""
......@@ -661,6 +700,10 @@ class Butterfly(AbstractOperation):
def evaluate(self, a, b):
return a + b, a - b
@property
def is_linear(self) -> bool:
return True
class MAD(AbstractOperation):
r"""
......@@ -700,6 +743,13 @@ class MAD(AbstractOperation):
def evaluate(self, a, b, c):
return a * b + c
@property
def is_linear(self) -> bool:
return (
self.input(0).connected_source.operation.is_constant
or self.input(1).connected_source.operation.is_constant
)
class SymmetricTwoportAdaptor(AbstractOperation):
r"""
......@@ -752,6 +802,10 @@ class SymmetricTwoportAdaptor(AbstractOperation):
"""Set the constant value of this operation."""
self.set_param("value", value)
@property
def is_linear(self) -> bool:
return True
class Reciprocal(AbstractOperation):
r"""
......
"""PlotWindow is a window in which simulation results are plotted."""
# TODO's:
# * Solve the legend update. That isn't working at all.
# * Zoom etc. Might need to change FigureCanvas. Or just something very little.
# * Add a function to run this as a "stand-alone".
import re
import sys
from typing import Dict, List, Optional, Tuple
from matplotlib.backends.backend_qt5agg import (
FigureCanvasQTAgg as FigureCanvas,
)
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
from qtpy.QtCore import Qt
from qtpy.QtGui import QKeySequence
# Intereme imports for the Plot class:
from qtpy.QtWidgets import ( # QFrame,; QScrollArea,; QLineEdit,; QSizePolicy,; QLabel,
from qtpy.QtWidgets import ( # QFrame,; QScrollArea,; QLineEdit,; QSizePolicy,; QLabel,; QFileDialog,; QShortcut,
QApplication,
QCheckBox,
QDialog,
QFileDialog,
QHBoxLayout,
QListWidget,
QListWidgetItem,
QPushButton,
QShortcut,
QSizePolicy,
QVBoxLayout,
)
class PlotCanvas(FigureCanvas):
"""PlotCanvas is used as a part in the PlotWindow."""
def __init__(self, logger, parent=None, width=5, height=4, dpi=100):
fig = Figure(figsize=(width, height), dpi=dpi)
super().__init__(fig)
self.axes = fig.add_subplot(111)
self.axes.xaxis.set_major_locator(MaxNLocator(integer=True))
self.legend = None
self.logger = logger
FigureCanvas.updateGeometry(self)
self.save_figure = QShortcut(QKeySequence("Ctrl+S"), self)
self.save_figure.activated.connect(self._save_plot_figure)
def _save_plot_figure(self):
self.logger.info(f"Saving plot of figure: {self.sfg.name}.")
file_choices = "PNG (*.png)|*.png"
path, ext = QFileDialog.getSaveFileName(
self, "Save file", "", file_choices
)
path = path.encode("utf-8")
if not path[-4:] == file_choices[-4:].encode("utf-8"):
path += file_choices[-4:].encode("utf-8")
if path:
self.print_figure(path.decode(), dpi=self.dpi)
self.logger.info(f"Saved plot: {self.sfg.name} to path: {path}.")
class PlotWindow(QDialog):
"""Dialog for plotting the result of a simulation."""
def __init__(
self,
sim_result,
# sfg_name="{sfg_name}",
# window=None,
sim_result: Dict[str, List[complex]],
logger=print,
sfg_name: Optional[str] = None,
parent=None,
# width=5,
# height=4,
# dpi=100,
):
super().__init__(parent=parent)
# self._window = window
self.setWindowFlags(
Qt.WindowTitleHint
| Qt.WindowCloseButtonHint
| Qt.WindowMinimizeButtonHint
| Qt.WindowMaximizeButtonHint
| Qt.WindowStaysOnTopHint
)
self.setWindowTitle("Simulation result")
self.sim_result = sim_result
title = (
f"Simulation results: {sfg_name}"
if sfg_name is not None
else "Simulation results"
)
self.setWindowTitle(title)
self._auto_redraw = False
# Categorise sim_results into inputs, outputs, delays, others
......@@ -102,7 +70,7 @@ class PlotWindow(QDialog):
sim_res_others[key] = sim_result[key]
# Layout: ############################################
# | list | |
# | list | icons |
# | ... | plot |
# | misc | |
......@@ -110,25 +78,30 @@ class PlotWindow(QDialog):
self.setLayout(self.dialog_layout)
listlayout = QVBoxLayout()
self.plotcanvas = PlotCanvas(
logger=logger, parent=self, width=5, height=4, dpi=100
)
plotlayout = QVBoxLayout()
self.dialog_layout.addLayout(listlayout)
self.dialog_layout.addWidget(self.plotcanvas)
self.dialog_layout.addLayout(plotlayout)
########### Plot: ##############
# Do this before the list layout, as the list layout will re/set visibility
# Note: The order is of importens. Interesting lines last, to be on top.
# self.plotcanvas = PlotCanvas(
# logger=logger, parent=self, width=5, height=4, dpi=100
# )
self._plot_fig = Figure(figsize=(5, 4), layout="compressed")
self._plot_axes = self._plot_fig.add_subplot(111)
self._plot_axes.xaxis.set_major_locator(MaxNLocator(integer=True))
self._lines = {}
for key in (
sim_res_others | sim_res_delays | sim_res_ins | sim_res_outs
):
line = self.plotcanvas.axes.plot(
sim_result[key], visible=False, label=key
)
self._lines[key] = line
self.plotcanvas.legend = self.plotcanvas.axes.legend()
for key in sim_res_others | sim_res_delays | sim_res_ins | sim_res_outs:
line = self._plot_axes.plot(sim_result[key], label=key)
self._lines[key] = line[0]
self._plot_canvas = FigureCanvas(self._plot_fig)
plotlayout.addWidget(NavigationToolbar(self._plot_canvas, self))
plotlayout.addWidget(self._plot_canvas)
########### List layout: ##############
......@@ -144,11 +117,10 @@ class PlotWindow(QDialog):
# Add the entire list
self.checklist = QListWidget()
self.checklist.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred)
self.checklist.itemChanged.connect(self._item_change)
listitems = {}
for key in (
sim_res_ins | sim_res_outs | sim_res_delays | sim_res_others
):
for key in sim_res_ins | sim_res_outs | sim_res_delays | sim_res_others:
listitem = QListWidgetItem(key)
listitems[key] = listitem
self.checklist.addItem(listitem)
......@@ -157,10 +129,11 @@ class PlotWindow(QDialog):
)
for key in sim_res_outs:
listitems[key].setCheckState(Qt.CheckState.Checked)
self.checklist.setFixedWidth(150)
# self.checklist.setFixedWidth(150)
listlayout.addWidget(self.checklist)
# Add additional checkboxes
self._legend = self._plot_axes.legend()
self.legend_checkbox = QCheckBox("&Legend")
self.legend_checkbox.stateChanged.connect(self._legend_checkbox_change)
self.legend_checkbox.setCheckState(Qt.CheckState.Checked)
......@@ -180,13 +153,11 @@ class PlotWindow(QDialog):
self._auto_redraw = True
def _legend_checkbox_change(self, checkState):
self.plotcanvas.legend.set(
visible=(checkState == Qt.CheckState.Checked)
)
self._legend.set(visible=(checkState == Qt.CheckState.Checked))
if self._auto_redraw:
if checkState == Qt.CheckState.Checked:
self.plotcanvas.legend = self.plotcanvas.axes.legend()
self.plotcanvas.draw()
self._legend = self._plot_axes.legend()
self._plot_canvas.draw()
# def _ontop_checkbox_change(self, checkState):
# Bugg: It seems the window closes if you change the WindowStaysOnTopHint.
......@@ -200,30 +171,50 @@ class PlotWindow(QDialog):
for x in range(self.checklist.count()):
self.checklist.item(x).setCheckState(Qt.CheckState.Checked)
self._auto_redraw = True
self.plotcanvas.draw()
self._update_legend()
def _update_legend(self):
self._legend = self._plot_axes.legend()
self._plot_canvas.draw()
def _button_none_click(self, event):
self._auto_redraw = False
for x in range(self.checklist.count()):
self.checklist.item(x).setCheckState(Qt.CheckState.Unchecked)
self._auto_redraw = True
self.plotcanvas.draw()
self._update_legend()
def _item_change(self, listitem):
key = listitem.text()
self._lines[key][0].set(
visible=(listitem.checkState() == Qt.CheckState.Checked)
)
if listitem.checkState() == Qt.CheckState.Checked:
self._plot_axes.add_line(self._lines[key])
else:
self._lines[key].remove()
if self._auto_redraw:
if self.legend_checkbox.checkState == Qt.CheckState.Checked:
self.plotcanvas.legend = self.plotcanvas.axes.legend()
self.plotcanvas.draw()
self._update_legend()
def start_simulation_dialog(
sim_results: Dict[str, List[complex]], sfg_name: Optional[str] = None
):
"""
Display the simulation results window.
Parameters
----------
sim_results : dict
Simulation results of the form obtained from :attr:`~b_asic.simulation.Simulation.results`.
sfg_name : str, optional
DESCRIPTION. The default is None.
"""
app = QApplication(sys.argv)
win = PlotWindow(sim_result=sim_results, sfg_name=sfg_name)
win.exec_()
# Simple test of the dialog
if __name__ == "__main__":
app = QApplication(sys.argv)
# sim_res = {"c1": [3, 6, 7], "c2": [4, 5, 5], "bfly1.0": [7, 0, 0], "bfly1.1": [-1, 0, 2], "0": [1, 2, 3]}
sim_res = {
'0': [0.5, 0.5, 0, 0],
'add1': [0.5, 0.5, 0, 0],
......@@ -232,9 +223,4 @@ if __name__ == "__main__":
'in1': [1, 0, 0, 0],
't1': [0, 1, 0, 0],
}
win = PlotWindow(
# window=None, sim_result=sim_res, sfg_name="hej", logger=print
sim_result=sim_res,
)
win.exec_()
# win.show()
start_simulation_dialog(sim_res, "Test data")
......@@ -25,12 +25,7 @@ from typing import (
overload,
)
from b_asic.graph_component import (
AbstractGraphComponent,
GraphComponent,
GraphID,
Name,
)
from b_asic.graph_component import AbstractGraphComponent, GraphComponent, GraphID, Name
from b_asic.port import InputPort, OutputPort, SignalSourceProvider
from b_asic.signal import Signal
from b_asic.types import Num
......@@ -403,9 +398,7 @@ class Operation(GraphComponent, SignalSourceProvider):
@abstractmethod
def get_plot_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]:
"""
Return a tuple containing coordinates for the two polygons outlining
the latency and execution time of the operation.
......@@ -413,24 +406,6 @@ class Operation(GraphComponent, SignalSourceProvider):
"""
raise NotImplementedError
@abstractmethod
def get_io_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
"""
Return a tuple containing coordinates for inputs and outputs, respectively.
These maps to the polygons and are corresponding to a start time of 0
and height 1.
See also
========
get_input_coordinates
get_output_coordinates
"""
raise NotImplementedError
@abstractmethod
def get_input_coordinates(
self,
......@@ -442,7 +417,6 @@ class Operation(GraphComponent, SignalSourceProvider):
See also
========
get_io_coordinates
get_output_coordinates
"""
raise NotImplementedError
......@@ -459,7 +433,6 @@ class Operation(GraphComponent, SignalSourceProvider):
See also
========
get_input_coordinates
get_io_coordinates
"""
raise NotImplementedError
......@@ -494,6 +467,22 @@ class Operation(GraphComponent, SignalSourceProvider):
def _check_all_latencies_set(self) -> None:
raise NotImplementedError
@property
@abstractmethod
def is_linear(self) -> bool:
"""
Return True if the operation is linear.
"""
raise NotImplementedError
@property
@abstractmethod
def is_constant(self) -> bool:
"""
Return True if the output of the operation is constant.
"""
raise NotImplementedError
class AbstractOperation(Operation, AbstractGraphComponent):
"""
......@@ -512,9 +501,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
input_count: int,
output_count: int,
name: Name = Name(""),
input_sources: Optional[
Sequence[Optional[SignalSourceProvider]]
] = None,
input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None,
latency: Optional[int] = None,
latency_offsets: Optional[Dict[str, int]] = None,
execution_time: Optional[int] = None,
......@@ -575,9 +562,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
@overload
@abstractmethod
def evaluate(
self, *inputs: Num
) -> List[Num]: # pylint: disable=arguments-differ
def evaluate(self, *inputs: Num) -> List[Num]: # pylint: disable=arguments-differ
...
@abstractmethod
......@@ -601,34 +586,25 @@ class AbstractOperation(Operation, AbstractGraphComponent):
# Import here to avoid circular imports.
from b_asic.core_operations import Addition, Constant
return Addition(
Constant(src) if isinstance(src, Number) else src, self
)
return Addition(Constant(src) if isinstance(src, Number) else src, self)
def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Subtraction
return Subtraction(
self, Constant(src) if isinstance(src, Number) else src
)
return Subtraction(self, Constant(src) if isinstance(src, Number) else src)
def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Subtraction
return Subtraction(
Constant(src) if isinstance(src, Number) else src, self
)
return Subtraction(Constant(src) if isinstance(src, Number) else src, self)
def __mul__(
self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
# Import here to avoid circular imports.
from b_asic.core_operations import (
ConstantMultiplication,
Multiplication,
)
from b_asic.core_operations import ConstantMultiplication, Multiplication
return (
ConstantMultiplication(src, self)
......@@ -640,10 +616,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
# Import here to avoid circular imports.
from b_asic.core_operations import (
ConstantMultiplication,
Multiplication,
)
from b_asic.core_operations import ConstantMultiplication, Multiplication
return (
ConstantMultiplication(src, self)
......@@ -655,9 +628,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Division
return Division(
self, Constant(src) if isinstance(src, Number) else src
)
return Division(self, Constant(src) if isinstance(src, Number) else src)
def __rtruediv__(
self, src: Union[SignalSourceProvider, Num]
......@@ -835,8 +806,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
self, delays: Optional[DelayMap] = None, prefix: str = ""
) -> Sequence[Optional[Num]]:
return [
self.current_output(i, delays, prefix)
for i in range(self.output_count)
self.current_output(i, delays, prefix) for i in range(self.output_count)
]
def evaluate_outputs(
......@@ -927,9 +897,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
Operations input ports.
"""
return [
signal.source.operation
for signal in self.input_signals
if signal.source
signal.source.operation for signal in self.input_signals if signal.source
]
@property
......@@ -1008,10 +976,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return max(
(
(
cast(int, output.latency_offset)
- cast(int, input.latency_offset)
)
(cast(int, output.latency_offset) - cast(int, input.latency_offset))
for output, input in it.product(self.outputs, self.inputs)
)
)
......@@ -1039,7 +1004,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
if any(val is None for val in latency_offsets):
raise ValueError(
"Missing latencies for inputs"
"Missing latencies for input(s)"
f" {[i for (i, latency) in enumerate(latency_offsets) if latency is None]}"
)
......@@ -1050,8 +1015,8 @@ class AbstractOperation(Operation, AbstractGraphComponent):
if any(val is None for val in latency_offsets):
raise ValueError(
"Missing latencies for outputs"
f" {[i for i in latency_offsets if i is not None]}"
"Missing latencies for output(s)"
f" {[i for (i, latency) in enumerate(latency_offsets) if latency is None]}"
)
return cast(List[int], latency_offsets)
......@@ -1116,9 +1081,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def get_plot_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]:
# Doc-string inherited
return (
self._get_plot_coordinates_for_latency(),
......@@ -1169,28 +1132,34 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def get_input_coordinates(self) -> Tuple[Tuple[float, float], ...]:
# doc-string inherited
num_in = self.input_count
return tuple(
(
self.input_latency_offsets()[k],
(1 + 2 * k) / (2 * len(self.inputs)),
(1 + 2 * k) / (2 * num_in),
)
for k in range(len(self.inputs))
for k in range(num_in)
)
def get_output_coordinates(self) -> Tuple[Tuple[float, float], ...]:
# doc-string inherited
num_out = self.output_count
return tuple(
(
self.output_latency_offsets()[k],
(1 + 2 * k) / (2 * len(self.outputs)),
(1 + 2 * k) / (2 * num_out),
)
for k in range(len(self.outputs))
for k in range(num_out)
)
def get_io_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
# Doc-string inherited
return self.get_input_coordinates(), self.get_output_coordinates()
@property
def is_linear(self) -> bool:
if self.is_constant:
return True
return False
@property
def is_constant(self) -> bool:
return all(
input.connected_source.operation.is_constant for input in self.inputs
)
......@@ -58,7 +58,7 @@ class Process:
return self._name
def __repr__(self) -> str:
return f"Process({self.start_time}, {self.execution_time}, {self.name})"
return f"Process({self.start_time}, {self.execution_time}, {self.name!r})"
# Static counter for default names
_name_cnt = 0
......@@ -166,7 +166,8 @@ class PlainMemoryVariable(Process):
Identifier for the source of the memory variable.
reads : {int: int, ...}
Dictionary where the key is the destination identifier and the value
is the time after *write_time* that the memory variable is read.
is the time after *write_time* that the memory variable is read, i.e., the
lifetime of the variable.
name : str, optional
The name of the process.
"""
......
......@@ -120,7 +120,7 @@ class ProcessCollection:
return self._collection
def __len__(self):
return len(self.__collection__)
return len(self._collection)
def add_process(self, process: Process):
"""
......
"""
B-ASIC Save/Load Structure Module.
Contains functions for saving/loading SFGs to/from strings that can be stored
as files.
Contains functions for saving/loading SFGs and Schedules to/from strings that can be
stored as files.
"""
from datetime import datetime
......@@ -11,11 +11,12 @@ from typing import Dict, Optional, Tuple, cast
from b_asic.graph_component import GraphComponent
from b_asic.port import InputPort
from b_asic.schedule import Schedule
from b_asic.signal_flow_graph import SFG
def sfg_to_python(
sfg: SFG, counter: int = 0, suffix: Optional[str] = None
sfg: SFG, counter: int = 0, suffix: Optional[str] = None, schedule=False
) -> str:
"""
Given an SFG structure try to serialize it for saving to a file.
......@@ -23,15 +24,23 @@ def sfg_to_python(
Parameters
==========
sfg : SFG
The SFG to serialize
The SFG to serialize.
counter : int, default: 0
Number used for naming the SFG. Enables SFGs in SFGs.
suffix : str, optional
String to append at the end of the result.
schedule : bool, default: False
True if printing a schedule.
"""
if not isinstance(sfg, SFG):
raise TypeError("An SFG must be provided")
_type = "Schedule" if schedule else "SFG"
result = (
'\n"""\nB-ASIC automatically generated SFG file.\n'
'\n"""\n'
+ f"B-ASIC automatically generated {_type} file.\n"
+ "Name: "
+ f"{sfg.name}"
+ "\n"
......@@ -44,6 +53,8 @@ def sfg_to_python(
result += "\nfrom b_asic import SFG, Signal, Input, Output"
for op_type in {type(op) for op in sfg.operations}:
result += f", {op_type.__name__}"
if schedule:
result += ", Schedule"
def kwarg_unpacker(comp: GraphComponent, params=None) -> str:
if params is None:
......@@ -61,56 +72,51 @@ def sfg_to_python(
params = {k: v for k, v in params.items() if v}
if params.get("latency_offsets", None) is not None:
params["latency_offsets"] = {
k: v
for k, v in params["latency_offsets"].items()
if v is not None
k: v for k, v in params["latency_offsets"].items() if v is not None
}
if not params["latency_offsets"]:
del params["latency_offsets"]
return ", ".join(
[f"{param}={value}" for param, value in params.items()]
)
return ", ".join([f"{param}={value}" for param, value in params.items()])
# No need to redefined I/Os
io_ops = [*sfg._input_operations, *sfg._output_operations]
io_ops = [*sfg.input_operations, *sfg.output_operations]
result += "\n# Inputs:\n"
for input_op in sfg._input_operations:
for input_op in sfg.input_operations:
result += f"{input_op.graph_id} = Input({kwarg_unpacker(input_op)})\n"
result += "\n# Outputs:\n"
for output_op in sfg._output_operations:
result += (
f"{output_op.graph_id} = Output({kwarg_unpacker(output_op)})\n"
)
for output_op in sfg.output_operations:
result += f"{output_op.graph_id} = Output({kwarg_unpacker(output_op)})\n"
result += "\n# Operations:\n"
for op in sfg.split():
if op in io_ops:
for operation in sfg.split():
if operation in io_ops:
continue
if isinstance(op, SFG):
if isinstance(operation, SFG):
counter += 1
result = sfg_to_python(op, counter) + result
result = sfg_to_python(operation, counter) + result
continue
result += (
f"{op.graph_id} = {op.__class__.__name__}({kwarg_unpacker(op)})\n"
f"{operation.graph_id} ="
f" {operation.__class__.__name__}({kwarg_unpacker(operation)})\n"
)
result += "\n# Signals:\n"
# Keep track of already existing connections to avoid adding duplicates
connections = []
for op in sfg.split():
for out in op.outputs:
for operation in sfg.split():
for out in operation.outputs:
for signal in out.signals:
destination = cast(InputPort, signal.destination)
dest_op = destination.operation
connection = (
f"\nSignal(source={op.graph_id}."
f"output({op.outputs.index(signal.source)}),"
f"Signal(source={operation.graph_id}."
f"output({operation.outputs.index(signal.source)}),"
f" destination={dest_op.graph_id}."
f"input({dest_op.inputs.index(destination)}))"
f"input({dest_op.inputs.index(destination)}))\n"
)
if connection in connections:
continue
......@@ -119,20 +125,14 @@ def sfg_to_python(
connections.append(connection)
inputs = "[" + ", ".join(op.graph_id for op in sfg.input_operations) + "]"
outputs = (
"[" + ", ".join(op.graph_id for op in sfg.output_operations) + "]"
)
sfg_name = (
sfg.name if sfg.name else f"sfg{counter}" if counter > 0 else "sfg"
)
sfg_name_var = sfg_name.replace(" ", "_")
outputs = "[" + ", ".join(op.graph_id for op in sfg.output_operations) + "]"
sfg_name = sfg.name if sfg.name else f"sfg{counter}" if counter > 0 else "sfg"
sfg_name_var = sfg_name.replace(" ", "_").replace("-", "_")
result += "\n# Signal flow graph:\n"
result += (
f"\n{sfg_name_var} = SFG(inputs={inputs}, outputs={outputs},"
f" name='{sfg_name}')\n"
)
result += (
"\n# SFG Properties:\n" + "prop = {'name':" + f"{sfg_name_var}" + "}"
f"{sfg_name_var} = SFG(inputs={inputs}, outputs={outputs}, name='{sfg_name}')\n"
)
result += "\n# SFG Properties:\n" + "prop = {'name':" + f"{sfg_name_var}" + "}\n"
if suffix is not None:
result += "\n" + suffix + "\n"
......@@ -142,15 +142,15 @@ def sfg_to_python(
def python_to_sfg(path: str) -> Tuple[SFG, Dict[str, Tuple[int, int]]]:
"""
Given a serialized file try to deserialize it and load it to the library.
Given a serialized file, try to deserialize it and load it to the library.
Parameters
==========
path : str
Path to file to read and deserialize.
"""
with open(path) as f:
code = compile(f.read(), path, "exec")
with open(path) as file:
code = compile(file.read(), path, "exec")
exec(code, globals(), locals())
return (
......@@ -159,3 +159,29 @@ def python_to_sfg(path: str) -> Tuple[SFG, Dict[str, Tuple[int, int]]]:
else [v for k, v in locals().items() if isinstance(v, SFG)][0],
locals()["positions"] if "positions" in locals() else {},
)
def schedule_to_python(schedule: Schedule) -> str:
"""
Given a schedule structure try to serialize it for saving to a file.
Parameters
==========
schedule : Schedule
The schedule to serialize.
"""
if not isinstance(schedule, Schedule):
raise TypeError("A Schedule must be provided")
sfg_name = (
schedule.sfg.name.replace(" ", "_").replace("-", "_")
if schedule.sfg.name
else "sfg"
)
result = "\n# Schedule:\n"
nonzerolaps = {gid: val for gid, val in dict(schedule.laps).items() if val}
result += (
f"{sfg_name}_schedule = Schedule({sfg_name}, {schedule.schedule_time},"
f" {schedule.cyclic}, 'provided', {schedule.start_times},"
f" {nonzerolaps})\n"
)
return sfg_to_python(schedule.sfg, schedule=True) + result
......@@ -31,7 +31,7 @@ from b_asic._preferences import (
from b_asic.graph_component import GraphID
from b_asic.operation import Operation
from b_asic.port import InputPort, OutputPort
from b_asic.process import MemoryVariable, Process
from b_asic.process import MemoryVariable
from b_asic.resources import ProcessCollection
from b_asic.signal_flow_graph import SFG
from b_asic.special_operations import Delay, Input, Output
......@@ -55,8 +55,15 @@ class Schedule:
algorithm.
cyclic : bool, default: False
If the schedule is cyclic.
scheduling_algorithm : {'ASAP'}, optional
scheduling_algorithm : {'ASAP', 'provided'}, optional
The scheduling algorithm to use. Currently, only "ASAP" is supported.
If 'provided', use provided *start_times* and *laps* dictionaries.
start_times : dict, optional
Dictionary with GraphIDs as keys and start times as values.
Used when *scheduling_algorithm* is 'provided'.
laps : dict, optional
Dictionary with GraphIDs as keys and laps as values.
Used when *scheduling_algorithm* is 'provided'.
"""
_sfg: SFG
......@@ -72,8 +79,14 @@ class Schedule:
schedule_time: Optional[int] = None,
cyclic: bool = False,
scheduling_algorithm: str = "ASAP",
start_times: Optional[Dict[GraphID, int]] = None,
laps: Optional[Dict[GraphID, int]] = None,
):
"""Construct a Schedule from an SFG."""
if not isinstance(sfg, SFG):
raise TypeError("An SFG must be provided")
self._original_sfg = sfg() # Make a copy
self._sfg = sfg
self._start_times = {}
self._laps = defaultdict(lambda: 0)
......@@ -81,6 +94,14 @@ class Schedule:
self._y_locations = defaultdict(lambda: None)
if scheduling_algorithm == "ASAP":
self._schedule_asap()
elif scheduling_algorithm == "provided":
if start_times is None:
raise ValueError("Must provide start_times when using 'provided'")
if laps is None:
raise ValueError("Must provide laps when using 'provided'")
self._start_times = start_times
self._laps.update(laps)
self._remove_delays_no_laps()
else:
raise NotImplementedError(
f"No algorithm with name: {scheduling_algorithm} defined."
......@@ -107,8 +128,8 @@ class Schedule:
"""Return the current maximum end time among all operations."""
max_end_time = 0
for graph_id, op_start_time in self._start_times.items():
op = cast(Operation, self._sfg.find_by_id(graph_id))
for outport in op.outputs:
operation = cast(Operation, self._sfg.find_by_id(graph_id))
for outport in operation.outputs:
max_end_time = max(
max_end_time,
op_start_time + cast(int, outport.latency_offset),
......@@ -149,8 +170,8 @@ class Schedule:
) -> Dict["OutputPort", Dict["Signal", int]]:
ret = {}
start_time = self._start_times[graph_id]
op = cast(Operation, self._sfg.find_by_id(graph_id))
for output_port in op.outputs:
operation = cast(Operation, self._sfg.find_by_id(graph_id))
for output_port in operation.outputs:
output_slacks = {}
available_time = start_time + cast(int, output_port.latency_offset)
......@@ -200,8 +221,8 @@ class Schedule:
def _backward_slacks(self, graph_id: GraphID) -> Dict[InputPort, Dict[Signal, int]]:
ret = {}
start_time = self._start_times[graph_id]
op = cast(Operation, self._sfg.find_by_id(graph_id))
for input_port in op.inputs:
operation = cast(Operation, self._sfg.find_by_id(graph_id))
for input_port in operation.inputs:
input_slacks = {}
usage_time = start_time + cast(int, input_port.latency_offset)
......@@ -244,6 +265,7 @@ class Schedule:
return self.backward_slack(graph_id), self.forward_slack(graph_id)
def print_slacks(self) -> None:
"""Print the slack times for all operations in the schedule."""
raise NotImplementedError
def set_schedule_time(self, time: int) -> "Schedule":
......@@ -269,22 +291,29 @@ class Schedule:
@property
def sfg(self) -> SFG:
return self._sfg
"""The SFG of the current schedule."""
return self._original_sfg
@property
def start_times(self) -> Dict[GraphID, int]:
"""The start times of the operations in the current schedule."""
return self._start_times
@property
def laps(self) -> Dict[GraphID, int]:
"""
The number of laps for the start times of the operations in the current schedule.
"""
return self._laps
@property
def schedule_time(self) -> int:
"""The schedule time of the current schedule."""
return self._schedule_time
@property
def cyclic(self) -> bool:
"""If the current schedule is cyclic."""
return self._cyclic
def increase_time_resolution(self, factor: int) -> "Schedule":
......@@ -314,8 +343,11 @@ class Schedule:
ret = [self._schedule_time, *self._start_times.values()]
# Loop over operations
for graph_id in self._start_times:
op = cast(Operation, self._sfg.find_by_id(graph_id))
ret += [cast(int, op.execution_time), *op.latency_offsets.values()]
operation = cast(Operation, self._sfg.find_by_id(graph_id))
ret += [
cast(int, operation.execution_time),
*operation.latency_offsets.values(),
]
# Remove not set values (None)
ret = [v for v in ret if v is not None]
return ret
......@@ -360,6 +392,75 @@ class Schedule:
self._schedule_time = self._schedule_time // factor
return self
def move_y_location(
self, graph_id: GraphID, new_y: int, insert: bool = False
) -> None:
"""
Move operation in y-direction and remove any empty rows.
Parameters
----------
graph_id : GraphID
The GraphID of the operation to move.
new_y : int
The new y-position of the operation.
insert : bool, optional
If True, all operations on that y-position will be moved one position.
The default is False.
"""
if insert:
for gid in self._y_locations:
if self.get_y_location(gid) >= new_y:
self.set_y_location(gid, self.get_y_location(gid) + 1)
self.set_y_location(graph_id, new_y)
used_locations = {*self._y_locations.values()}
possible_locations = set(range(max(used_locations) + 1))
if not possible_locations - used_locations:
return
remapping = {}
offset = 0
for loc in possible_locations:
if loc in used_locations:
remapping[loc] = loc - offset
else:
offset += 1
for gid, y_location in self._y_locations.items():
self._y_locations[gid] = remapping[self._y_locations[gid]]
def get_y_location(self, graph_id: GraphID) -> int:
"""
Get the y-position of the Operation with GraphID *graph_id*.
Parameters
----------
graph_id : GraphID
The GraphID of the operation.
Returns
-------
int
The y-position of the operation.
"""
return self._y_locations[graph_id]
def set_y_location(self, graph_id: GraphID, y_location: int) -> None:
"""
Set the y-position of the Operation with GraphID *graph_id* to *y_location*.
Parameters
----------
graph_id : GraphID
The GraphID of the operation to move.
y_location : int
The new y-position of the operation.
"""
self._y_locations[graph_id] = y_location
def move_operation(self, graph_id: GraphID, time: int) -> "Schedule":
"""
Move an operation in the schedule.
......@@ -463,7 +564,16 @@ class Schedule:
self._start_times[graph_id] = new_start
return self
def _remove_delays_no_laps(self) -> None:
"""Remove delay elements without updating laps. Used when loading schedule."""
delay_list = self._sfg.find_by_type_name(Delay.type_name())
while delay_list:
delay_op = cast(Delay, delay_list[0])
self._sfg = cast(SFG, self._sfg.remove_operation(delay_op.graph_id))
delay_list = self._sfg.find_by_type_name(Delay.type_name())
def _remove_delays(self) -> None:
"""Remove delay elements and update laps. Used after scheduling algorithm."""
delay_list = self._sfg.find_by_type_name(Delay.type_name())
while delay_list:
delay_op = cast(Delay, delay_list[0])
......@@ -477,35 +587,35 @@ class Schedule:
def _schedule_asap(self) -> None:
"""Schedule the operations using as-soon-as-possible scheduling."""
pl = self._sfg.get_precedence_list()
precedence_list = self._sfg.get_precedence_list()
if len(pl) < 2:
if len(precedence_list) < 2:
print("Empty signal flow graph cannot be scheduled.")
return
non_schedulable_ops = set()
for outport in pl[0]:
op = outport.operation
if op.type_name() not in [Delay.type_name()]:
if op.graph_id not in self._start_times:
for outport in precedence_list[0]:
operation = outport.operation
if operation.type_name() not in [Delay.type_name()]:
if operation.graph_id not in self._start_times:
# Set start time of all operations in the first iter to 0
self._start_times[op.graph_id] = 0
self._start_times[operation.graph_id] = 0
else:
non_schedulable_ops.add(op.graph_id)
non_schedulable_ops.add(operation.graph_id)
for outport in pl[1]:
op = outport.operation
if op.graph_id not in self._start_times:
for outport in precedence_list[1]:
operation = outport.operation
if operation.graph_id not in self._start_times:
# Set start time of all operations in the first iter to 0
self._start_times[op.graph_id] = 0
self._start_times[operation.graph_id] = 0
for outports in pl[2:]:
for outports in precedence_list[2:]:
for outport in outports:
op = outport.operation
if op.graph_id not in self._start_times:
operation = outport.operation
if operation.graph_id not in self._start_times:
# Schedule the operation if it does not have a start time yet.
op_start_time = 0
for inport in op.inputs:
for inport in operation.inputs:
if len(inport.signals) != 1:
raise ValueError(
"Error in scheduling, dangling input port detected."
......@@ -545,7 +655,7 @@ class Schedule:
op_start_time_from_in = source_end_time - inport.latency_offset
op_start_time = max(op_start_time, op_start_time_from_in)
self._start_times[op.graph_id] = op_start_time
self._start_times[operation.graph_id] = op_start_time
for output in self._sfg.find_by_type_name(Output.type_name()):
output = cast(Output, output)
source_port = cast(OutputPort, output.inputs[0].signals[0].source)
......@@ -650,7 +760,7 @@ class Schedule:
line_cache.append(start)
elif end[0] == start[0]:
p = Path(
path = Path(
[
start,
[start[0] + SPLINE_OFFSET, start[1]],
......@@ -670,16 +780,16 @@ class Schedule:
Path.CURVE4,
],
)
pp = PathPatch(
p,
path_patch = PathPatch(
path,
fc='none',
ec=_SIGNAL_COLOR,
lw=SIGNAL_LINEWIDTH,
zorder=10,
)
ax.add_patch(pp)
ax.add_patch(path_patch)
else:
p = Path(
path = Path(
[
start,
[(start[0] + end[0]) / 2, start[1]],
......@@ -688,14 +798,14 @@ class Schedule:
],
[Path.MOVETO, Path.CURVE4, Path.CURVE4, Path.CURVE4],
)
pp = PathPatch(
p,
path_patch = PathPatch(
path,
fc='none',
ec=_SIGNAL_COLOR,
lw=SIGNAL_LINEWIDTH,
zorder=10,
)
ax.add_patch(pp)
ax.add_patch(path_patch)
def _draw_offset_arrow(start, end, start_offset, end_offset, name="", laps=0):
"""Draw an arrow from *start* to *end*, but with an offset."""
......@@ -712,12 +822,12 @@ class Schedule:
ax.grid()
for graph_id, op_start_time in self._start_times.items():
y_pos = self._get_y_position(graph_id, operation_gap=operation_gap)
op = cast(Operation, self._sfg.find_by_id(graph_id))
operation = cast(Operation, self._sfg.find_by_id(graph_id))
# Rewrite to make better use of NumPy
(
latency_coordinates,
execution_time_coordinates,
) = op.get_plot_coordinates()
) = operation.get_plot_coordinates()
_x, _y = zip(*latency_coordinates)
x = np.array(_x)
y = np.array(_y)
......@@ -737,11 +847,11 @@ class Schedule:
yticklabels.append(cast(Operation, self._sfg.find_by_id(graph_id)).name)
for graph_id, op_start_time in self._start_times.items():
op = cast(Operation, self._sfg.find_by_id(graph_id))
out_coordinates = op.get_output_coordinates()
operation = cast(Operation, self._sfg.find_by_id(graph_id))
out_coordinates = operation.get_output_coordinates()
source_y_pos = self._get_y_position(graph_id, operation_gap=operation_gap)
for output_port in op.outputs:
for output_port in operation.outputs:
for output_signal in output_port.signals:
destination = cast(InputPort, output_signal.destination)
destination_op = destination.operation
......@@ -786,7 +896,7 @@ class Schedule:
def _reset_y_locations(self) -> None:
"""Reset all the y-locations in the schedule to None"""
self._y_locations = self._y_locations = defaultdict(lambda: None)
self._y_locations = defaultdict(lambda: None)
def plot_in_axes(self, ax: Axes, operation_gap: Optional[float] = None) -> None:
"""
......@@ -839,7 +949,7 @@ class Schedule:
"""
fig, ax = plt.subplots()
self._plot_schedule(ax)
f = io.StringIO()
fig.savefig(f, format="svg")
buffer = io.StringIO()
fig.savefig(buffer, format="svg")
return f.getvalue()
return buffer.getvalue()
......@@ -482,9 +482,13 @@ class MainWindow(QMainWindow, Ui_MainWindow):
self._graph._signals.schedule_time_changed.connect(
self.info_table_update_schedule
)
self._graph._signals.redraw_all.connect(self._redraw_all)
self.info_table_fill_schedule(self._schedule)
self.update_statusbar(self.tr("Schedule loaded successfully"))
def _redraw_all(self) -> None:
self._graph._redraw_all()
def update_statusbar(self, msg: str) -> None:
"""
Write *msg* to the statusbar with temporarily policy.
......
......@@ -219,8 +219,6 @@ class OperationItem(QGraphicsItemGroup):
# component item
self._set_background(OPERATION_LATENCY_INACTIVE) # used by component filling
inputs, outputs = self._operation.get_io_coordinates()
def create_ports(io_coordinates, prefix):
for i, (x, y) in enumerate(io_coordinates):
pos = QPointF(x, y * self._height)
......@@ -235,8 +233,8 @@ class OperationItem(QGraphicsItemGroup):
new_port.setPos(port_pos.x(), port_pos.y())
self._port_items.append(new_port)
create_ports(inputs, "in")
create_ports(outputs, "out")
create_ports(self._operation.get_input_coordinates(), "in")
create_ports(self._operation.get_output_coordinates(), "out")
# op-id/label
self._label_item = QGraphicsSimpleTextItem(self._operation.graph_id)
......
......@@ -6,7 +6,7 @@ B-ASIC Scheduler-GUI Graphics Scheduler Event Module.
Contains the scheduler_ui SchedulerEvent class containing event filters and
handlers for SchedulerItem objects.
"""
import math
from typing import List, Optional, overload
# QGraphics and QPainter imports
......@@ -44,12 +44,14 @@ class SchedulerEvent: # PyQt5
component_selected = Signal(str)
schedule_time_changed = Signal()
component_moved = Signal(str)
redraw_all = Signal()
_axes: Optional[AxesItem]
_current_pos: QPointF
_delta_time: int
_signals: Signals # PyQt5
_schedule: Schedule
_old_op_position: int = -1
def __init__(self, parent: Optional[QGraphicsItem] = None): # PyQt5
super().__init__(parent=parent)
......@@ -195,20 +197,23 @@ class SchedulerEvent: # PyQt5
def update_pos(operation_item, dx, dy):
pos_x = operation_item.x() + dx
pos_y = operation_item.y() + dy * (OPERATION_GAP + OPERATION_HEIGHT)
if self.is_component_valid_pos(operation_item, pos_x):
pos_y = operation_item.y() + dy * (OPERATION_GAP + OPERATION_HEIGHT)
operation_item.setX(pos_x)
operation_item.setY(pos_y)
self._current_pos.setX(self._current_pos.x() + dx)
self._current_pos.setY(self._current_pos.y() + dy)
self._redraw_lines(operation_item)
self._schedule._y_locations[operation_item.operation.graph_id] += dy
gid = operation_item.operation.graph_id
self._schedule.set_y_location(
gid, dy + self._schedule.get_y_location(gid)
)
item: OperationItem = self.scene().mouseGrabberItem()
delta_x = (item.mapToParent(event.pos()) - self._current_pos).x()
delta_y = (item.mapToParent(event.pos()) - self._current_pos).y()
delta_y_steps = round(delta_y / (OPERATION_GAP + OPERATION_HEIGHT))
delta_y_steps = round(2 * delta_y / (OPERATION_GAP + OPERATION_HEIGHT)) / 2
if delta_x > 0.505:
update_pos(item, 1, delta_y_steps)
elif delta_x < -0.505:
......@@ -224,6 +229,7 @@ class SchedulerEvent: # PyQt5
allows the item to receive future move, release and double-click events.
"""
item: OperationItem = self.scene().mouseGrabberItem()
self._old_op_position = self._schedule.get_y_location(item.operation.graph_id)
self._signals.component_selected.emit(item.graph_id)
self._current_pos = item.mapToParent(event.pos())
self.set_item_active(item)
......@@ -242,10 +248,20 @@ class SchedulerEvent: # PyQt5
if pos_x > self._schedule.schedule_time:
pos_x = pos_x % self._schedule.schedule_time
redraw = True
pos_y = self._schedule.get_y_location(item.operation.graph_id)
# Check move in y-direction
if pos_y != self._old_op_position:
self._schedule.move_y_location(
item.operation.graph_id,
math.ceil(pos_y),
(pos_y % 1) != 0,
)
self._signals.redraw_all.emit()
# Operation has been moved in x-direction
if redraw:
item.setX(pos_x)
self._redraw_lines(item)
self._signals.component_moved.emit(item.graph_id)
self._signals.component_moved.emit(item.graph_id)
def operation_mouseDoubleClickEvent(self, event: QGraphicsSceneMouseEvent) -> None:
...
......
......@@ -224,15 +224,18 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): # PySide2 / PyQt5
def _redraw_from_start(self) -> None:
self.schedule._reset_y_locations()
for graph_id in {
k: v
for k, v in sorted(
self.schedule.start_times.items(), key=lambda item: item[1]
)
}:
for graph_id in dict(
sorted(self.schedule.start_times.items(), key=lambda item: item[1])
):
self._set_position(graph_id)
self._redraw_all_lines()
def _redraw_all(self) -> None:
for graph_id in self._operation_items:
self._set_position(graph_id)
self._redraw_all_lines()
self._update_axes()
def _update_axes(self, build=False) -> None:
# build axes
schedule_time = self.schedule.schedule_time
......@@ -254,7 +257,7 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): # PySide2 / PyQt5
"""Make a new graph out of the stored attributes."""
# build components
for graph_id in self.schedule.start_times.keys():
operation = cast(Operation, self.schedule.sfg.find_by_id(graph_id))
operation = cast(Operation, self.schedule._sfg.find_by_id(graph_id))
component = OperationItem(operation, height=OPERATION_HEIGHT, parent=self)
self._operation_items[graph_id] = component
self._set_position(graph_id)
......
......@@ -124,7 +124,7 @@ def direct_form_fir(
name: Optional[str] = None,
mult_properties: Optional[Union[Dict[str, int], Dict[str, Dict[str, int]]]] = None,
add_properties: Optional[Union[Dict[str, int], Dict[str, Dict[str, int]]]] = None,
):
) -> SFG:
r"""
Generate a signal flow graph of a direct form FIR filter.
......@@ -148,7 +148,7 @@ def direct_form_fir(
The Output to connect the SFG to. If not provided, one will be generated.
name : Name, optional
The name of the SFG. If None, "WDF allpass section".
The name of the SFG. If None, "Direct-form FIR filter".
mult_properties : dictionary, optional
Properties passed to :class:`~b_asic.core_operations.ConstantMultiplication`.
......@@ -205,7 +205,7 @@ def transposed_direct_form_fir(
name: Optional[str] = None,
mult_properties: Optional[Union[Dict[str, int], Dict[str, Dict[str, int]]]] = None,
add_properties: Optional[Union[Dict[str, int], Dict[str, Dict[str, int]]]] = None,
):
) -> SFG:
r"""
Generate a signal flow graph of a transposed direct form FIR filter.
......@@ -229,7 +229,7 @@ def transposed_direct_form_fir(
The Output to connect the SFG to. If not provided, one will be generated.
name : Name, optional
The name of the SFG. If None, "WDF allpass section".
The name of the SFG. If None, "Transposed direct-form FIR filter".
mult_properties : dictionary, optional
Properties passed to :class:`~b_asic.core_operations.ConstantMultiplication`.
......
......@@ -769,7 +769,12 @@ class SFG(AbstractOperation):
for port in ports:
port_string = port.name
if port.operation.output_count > 1:
sub.node(port_string)
sub.node(
port_string,
shape='rectangle',
height="0.1",
width="0.1",
)
else:
sub.node(
port_string,
......@@ -1378,7 +1383,8 @@ class SFG(AbstractOperation):
dg.format = fmt
dg.view()
def critical_path(self):
def critical_path_time(self) -> int:
"""Return the time of the critical path."""
# Import here needed to avoid circular imports
from b_asic.schedule import Schedule
......@@ -1534,3 +1540,11 @@ class SFG(AbstractOperation):
assert len(ids) == len(set(ids))
return SFG(inputs=all_inputs, outputs=all_outputs)
@property
def is_linear(self) -> bool:
return all(op.is_linear for op in self.split())
@property
def is_constant(self) -> bool:
return all(output.is_constant for output in self._output_operations)
"""
B-ASIC signal generators
B-ASIC Signal Generator Module.
These can be used as input to Simulation to algorithmically provide signal values.
Especially, all classes defined here will act as a callable which accepts an integer
......@@ -22,7 +22,7 @@ class SignalGenerator:
"""
Base class for signal generators.
Handles operator overloading and defined the ``__call__`` method that should
Handles operator overloading and defines the ``__call__`` method that should
be overridden.
"""
......@@ -94,7 +94,7 @@ class Impulse(SignalGenerator):
def __call__(self, time: int) -> complex:
return 1 if time == self._delay else 0
def __repr__(self):
def __repr__(self) -> str:
return f"Impulse({self._delay})" if self._delay else "Impulse()"
......@@ -114,7 +114,7 @@ class Step(SignalGenerator):
def __call__(self, time: int) -> complex:
return 1 if time >= self._delay else 0
def __repr__(self):
def __repr__(self) -> str:
return f"Step({self._delay})" if self._delay else "Step()"
......@@ -134,7 +134,7 @@ class Constant(SignalGenerator):
def __call__(self, time: int) -> complex:
return self._constant
def __str__(self):
def __str__(self) -> str:
return f"{self._constant}"
......@@ -157,7 +157,7 @@ class ZeroPad(SignalGenerator):
return self._data[time]
return 0.0
def __repr__(self):
def __repr__(self) -> str:
return f"ZeroPad({self._data})"
......@@ -181,7 +181,7 @@ class Sinusoid(SignalGenerator):
def __call__(self, time: int) -> complex:
return sin(pi * (self._frequency * time + self._phase))
def __repr__(self):
def __repr__(self) -> str:
return (
f"Sinusoid({self._frequency}, {self._phase})"
if self._phase
......@@ -216,7 +216,7 @@ class Gaussian(SignalGenerator):
def __call__(self, time: int) -> complex:
return self._rng.normal(self._loc, self._scale)
def __repr__(self):
def __repr__(self) -> str:
ret_list = []
if self._seed is not None:
ret_list.append(f"seed={self._seed}")
......@@ -256,7 +256,7 @@ class Uniform(SignalGenerator):
def __call__(self, time: int) -> complex:
return self._rng.uniform(self._low, self._high)
def __repr__(self):
def __repr__(self) -> str:
ret_list = []
if self._seed is not None:
ret_list.append(f"seed={self._seed}")
......@@ -268,6 +268,34 @@ class Uniform(SignalGenerator):
return f"Uniform({args})"
class Delay(SignalGenerator):
"""
Signal generator that delays the value of another signal generator.
This can used to easily delay a sequence during simulation.
.. note:: Although the purpose is to delay, it is also possible to look ahead by
providing a negative delay.
Parameters
----------
generator : SignalGenerator
The signal generator to delay the output of.
delay : int, default: 1
The number of time units to delay the generated signal.
"""
def __init__(self, generator: SignalGenerator, delay: int = 1) -> None:
self._generator = generator
self._delay = delay
def __call__(self, time: int) -> complex:
return self._generator(time - self._delay)
def __repr__(self) -> str:
return f"Delay({self._generator!r}, {self._delay})"
class _AddGenerator(SignalGenerator):
"""
Signal generator that adds two signals.
......@@ -280,7 +308,7 @@ class _AddGenerator(SignalGenerator):
def __call__(self, time: int) -> complex:
return self._a(time) + self._b(time)
def __repr__(self):
def __repr__(self) -> str:
return f"{self._a} + {self._b}"
......@@ -296,7 +324,7 @@ class _SubGenerator(SignalGenerator):
def __call__(self, time: int) -> complex:
return self._a(time) - self._b(time)
def __repr__(self):
def __repr__(self) -> str:
return f"{self._a} - {self._b}"
......@@ -312,7 +340,7 @@ class _MulGenerator(SignalGenerator):
def __call__(self, time: int) -> complex:
return self._a(time) * self._b(time)
def __repr__(self):
def __repr__(self) -> str:
a = (
f"({self._a})"
if isinstance(self._a, (_AddGenerator, _SubGenerator))
......@@ -338,7 +366,7 @@ class _DivGenerator(SignalGenerator):
def __call__(self, time: int) -> complex:
return self._a(time) / self._b(time)
def __repr__(self):
def __repr__(self) -> str:
a = (
f"({self._a})"
if isinstance(self._a, (_AddGenerator, _SubGenerator))
......
......@@ -59,6 +59,9 @@ class Simulation:
input_providers: Optional[Sequence[Optional[InputProvider]]] = None,
):
"""Construct a Simulation of an SFG."""
if not isinstance(sfg, SFG):
raise TypeError("An SFG must be provided")
# Copy the SFG to make sure it's not modified from the outside.
self._sfg = sfg()
self._results = defaultdict(list)
......@@ -214,3 +217,10 @@ class Simulation:
Clear all current state of the simulation, except for the results and iteration.
"""
self._delays.clear()
def show(self) -> None:
"""Show the simulation results."""
# import here to avoid cyclic imports
from b_asic.gui_utils.plot_window import start_simulation_dialog
start_simulation_dialog(self.results, self._sfg.name)
......@@ -44,6 +44,10 @@ class Input(AbstractOperation):
def evaluate(self):
return self.param("value")
@property
def latency(self) -> int:
return self.latency_offsets["out0"]
@property
def value(self) -> Num:
"""Get the current value of this input."""
......@@ -56,9 +60,7 @@ class Input(AbstractOperation):
def get_plot_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]:
# Doc-string inherited
return (
(
......@@ -87,6 +89,14 @@ class Input(AbstractOperation):
# doc-string inherited
return ((0, 0.5),)
@property
def is_constant(self) -> bool:
return False
@property
def is_linear(self) -> bool:
return True
class Output(AbstractOperation):
"""
......@@ -122,9 +132,7 @@ class Output(AbstractOperation):
def get_plot_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]:
# Doc-string inherited
return (
((0, 0), (0, 1), (0.25, 1), (0.5, 0.5), (0.25, 0), (0, 0)),
......@@ -139,6 +147,14 @@ class Output(AbstractOperation):
# doc-string inherited
return tuple()
@property
def latency(self) -> int:
return self.latency_offsets["in0"]
@property
def is_linear(self) -> bool:
return True
class Delay(AbstractOperation):
"""
......@@ -174,9 +190,7 @@ class Delay(AbstractOperation):
self, index: int, delays: Optional[DelayMap] = None, prefix: str = ""
) -> Optional[Num]:
if delays is not None:
return delays.get(
self.key(index, prefix), self.param("initial_value")
)
return delays.get(self.key(index, prefix), self.param("initial_value"))
return self.param("initial_value")
def evaluate_output(
......@@ -190,9 +204,7 @@ class Delay(AbstractOperation):
truncate: bool = True,
) -> Num:
if index != 0:
raise IndexError(
f"Output index out of range (expected 0-0, got {index})"
)
raise IndexError(f"Output index out of range (expected 0-0, got {index})")
if len(input_values) != 1:
raise ValueError(
"Wrong number of inputs supplied to SFG for evaluation"
......@@ -221,3 +233,7 @@ class Delay(AbstractOperation):
def initial_value(self, value: Num) -> None:
"""Set the initial value of this delay."""
self.set_param("initial_value", value)
@property
def is_linear(self) -> bool:
return True