Skip to content
Snippets Groups Projects
Commit 93f159ce authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Cleanup code and add typing

parent a8b79691
No related branches found
No related tags found
1 merge request!234Cleanup code and add typing
Pipeline #90367 passed
......@@ -9,7 +9,7 @@ import logging
import os
import sys
from pprint import pprint
from typing import Optional, Tuple
from typing import List, Optional, Tuple
from qtpy.QtCore import QFileInfo, QSize, Qt
from qtpy.QtGui import QCursor, QIcon, QKeySequence, QPainter
......@@ -47,6 +47,7 @@ from b_asic.gui_utils.plot_window import PlotWindow
from b_asic.operation import Operation
from b_asic.port import InputPort, OutputPort
from b_asic.save_load_structure import python_to_sfg, sfg_to_python
from b_asic.signal import Signal
from b_asic.signal_flow_graph import SFG
# from b_asic import FastSimulation
......@@ -158,13 +159,12 @@ class MainWindow(QMainWindow):
self.toolbar.addAction("Clear workspace", self.clear_workspace)
def resizeEvent(self, event) -> None:
self.ui.operation_box.setGeometry(
10, 10, self.ui.operation_box.width(), self.height()
)
ui_width = self.ui.operation_box.width()
self.ui.operation_box.setGeometry(10, 10, ui_width, self.height())
self.graphic_view.setGeometry(
self.ui.operation_box.width() + 20,
ui_width + 20,
60,
self.width() - self.ui.operation_box.width() - 20,
self.width() - ui_width - 20,
self.height() - 30,
)
super().resizeEvent(event)
......@@ -214,20 +214,20 @@ class MainWindow(QMainWindow):
self.logger.info("Saved SFG to path: " + str(module))
def save_work(self, event=None):
def save_work(self, event=None) -> None:
self.sfg_widget = SelectSFGWindow(self)
self.sfg_widget.show()
# Wait for input to dialog.
self.sfg_widget.ok.connect(self._save_work)
def load_work(self, event=None):
def load_work(self, event=None) -> None:
module, accepted = QFileDialog().getOpenFileName()
if not accepted:
return
self._load_from_file(module)
def _load_from_file(self, module):
def _load_from_file(self, module) -> None:
self.logger.info("Loading SFG from path: " + str(module))
try:
sfg, positions = python_to_sfg(module)
......@@ -252,7 +252,7 @@ class MainWindow(QMainWindow):
self._load_sfg(sfg, positions)
self.logger.info("Loaded SFG from path: " + str(module))
def _load_sfg(self, sfg, positions=None):
def _load_sfg(self, sfg, positions=None) -> None:
if positions is None:
positions = {}
......@@ -299,11 +299,11 @@ class MainWindow(QMainWindow):
self.sfg_dict[sfg.name] = sfg
self.update()
def exit_app(self):
def exit_app(self) -> None:
self.logger.info("Exiting the application.")
QApplication.quit()
def clear_workspace(self):
def clear_workspace(self) -> None:
self.logger.info("Clearing workspace from operations and SFGs.")
self.pressed_operations.clear()
self.pressed_ports.clear()
......@@ -341,7 +341,7 @@ class MainWindow(QMainWindow):
sfg = SFG(inputs=inputs, outputs=outputs, name=name)
self.logger.info("Created SFG with name: %s from selected operations." % name)
def check_equality(signal, signal_2):
def check_equality(signal: Signal, signal_2: Signal) -> bool:
if not (
signal.source.operation.type_name()
== signal_2.source.operation.type_name()
......@@ -440,11 +440,11 @@ class MainWindow(QMainWindow):
self.sfg_dict[sfg.name] = sfg
def _show_precedence_graph(self, event=None) -> None:
self.dialog = ShowPCWindow(self)
self.dialog.add_sfg_to_dialog()
self.dialog.show()
self._precedence_graph_dialog = ShowPCWindow(self)
self._precedence_graph_dialog.add_sfg_to_dialog()
self._precedence_graph_dialog.show()
def get_operations_from_namespace(self, namespace) -> None:
def get_operations_from_namespace(self, namespace) -> List[str]:
self.logger.info(
"Fetching operations from namespace: " + str(namespace.__name__)
)
......@@ -667,11 +667,11 @@ class MainWindow(QMainWindow):
self.update()
def paintEvent(self, event):
def paintEvent(self, event) -> None:
for signal in self.signalPortDict.keys():
signal.moveLine()
def _select_operations(self):
def _select_operations(self) -> None:
selected = [button.widget() for button in self.scene.selectedItems()]
for button in selected:
button._toggle_button(pressed=False)
......@@ -682,8 +682,8 @@ class MainWindow(QMainWindow):
self.pressed_operations = selected
def _simulate_sfg(self):
for sfg, properties in self.dialog.properties.items():
def _simulate_sfg(self) -> None:
for sfg, properties in self._simulation_dialog.properties.items():
self.logger.info("Simulating SFG with name: %s" % str(sfg.name))
simulation = FastSimulation(sfg, input_providers=properties["input_values"])
l_result = simulation.run_for(
......@@ -697,36 +697,32 @@ class MainWindow(QMainWindow):
if properties["show_plot"]:
self.logger.info("Opening plot for SFG with name: " + str(sfg.name))
self.logger.info(
"To save the plot press 'Ctrl+S' when the plot is focused."
)
# self.plot = Plot(simulation, sfg, self)
self.plot = PlotWindow(simulation.results)
self.plot.show()
self._plot = PlotWindow(simulation.results, sfg_name=sfg.name)
self._plot.show()
def simulate_sfg(self, event=None):
self.dialog = SimulateSFGWindow(self)
def simulate_sfg(self, event=None) -> None:
self._simulation_dialog = SimulateSFGWindow(self)
for _, sfg in self.sfg_dict.items():
self.dialog.add_sfg_to_dialog(sfg)
self._simulation_dialog.add_sfg_to_dialog(sfg)
self.dialog.show()
self._simulation_dialog.show()
# Wait for input to dialog.
# Kinda buggy because of the separate window in the same thread.
self.dialog.simulate.connect(self._simulate_sfg)
self._simulation_dialog.simulate.connect(self._simulate_sfg)
def display_faq_page(self, event=None):
self.faq_page = FaqWindow(self)
self.faq_page.scroll_area.show()
def display_faq_page(self, event=None) -> None:
self._faq_page = FaqWindow(self)
self._faq_page.scroll_area.show()
def display_about_page(self, event=None):
self.about_page = AboutWindow(self)
self.about_page.show()
def display_about_page(self, event=None) -> None:
self._about_page = AboutWindow(self)
self._about_page.show()
def display_keybinds_page(self, event=None):
self.keybinds_page = KeybindsWindow(self)
self.keybinds_page.show()
def display_keybinds_page(self, event=None) -> None:
self._keybinds_page = KeybindsWindow(self)
self._keybinds_page.show()
def start_gui():
......
......@@ -32,6 +32,19 @@ class SignalGeneratorInput(QGridLayout):
"""Return the SignalGenerator based on the graphical input."""
raise NotImplementedError
def _parse_number(self, string, _type, name, default):
string = string.strip()
try:
if not string:
return default
return _type(string)
except ValueError:
self._logger.warning(
f"Cannot parse {name}: {string} not a {_type.__name__}, setting to"
f" {default}"
)
return default
class DelayInput(SignalGeneratorInput):
"""
......@@ -82,6 +95,7 @@ class ZeroPadInput(SignalGeneratorInput):
self.input_label = QLabel("Input")
self.addWidget(self.input_label, 0, 0)
self.input_sequence = QLineEdit()
self.input_sequence.setPlaceholderText("0.1, -0.2, 0.7")
self.addWidget(self.input_sequence, 0, 1)
def get_generator(self) -> SignalGenerator:
......@@ -91,14 +105,11 @@ class ZeroPadInput(SignalGeneratorInput):
try:
if not val:
val = 0
val = complex(val)
except ValueError:
self._logger.warning(f"Skipping value: {val}, not a digit.")
continue
input_values.append(val)
return ZeroPad(input_values)
......@@ -138,34 +149,20 @@ class SinusoidInput(SignalGeneratorInput):
self.frequency_label = QLabel("Frequency")
self.addWidget(self.frequency_label, 0, 0)
self.frequency_input = QLineEdit()
self.frequency_input.setText("0.1")
self.addWidget(self.frequency_input, 0, 1)
self.phase_label = QLabel("Phase")
self.addWidget(self.phase_label, 1, 0)
self.phase_input = QLineEdit()
self.phase_input.setText("0.0")
self.addWidget(self.phase_input, 1, 1)
def get_generator(self) -> SignalGenerator:
frequency = self.frequency_input.text().strip()
try:
if not frequency:
frequency = 0.1
frequency = float(frequency)
except ValueError:
self._logger.warning(f"Cannot parse frequency: {frequency} not a number.")
frequency = 0.1
phase = self.phase_input.text().strip()
try:
if not phase:
phase = 0
phase = float(phase)
except ValueError:
self._logger.warning(f"Cannot parse phase: {phase} not a number.")
phase = 0
frequency = self._parse_number(
self.frequency_input.text(), float, "Frequency", 0.1
)
phase = self._parse_number(self.phase_input.text(), float, "Phase", 0.0)
return Sinusoid(frequency, phase)
......@@ -196,26 +193,10 @@ class GaussianInput(SignalGeneratorInput):
self.addWidget(self.seed_spin_box, 2, 1)
def get_generator(self) -> SignalGenerator:
scale = self.scale_input.text().strip()
try:
if not scale:
scale = 1
scale = float(scale)
except ValueError:
self._logger.warning(f"Cannot parse scale: {scale} not a number.")
scale = 1
loc = self.loc_input.text().strip()
try:
if not loc:
loc = 0
loc = float(loc)
except ValueError:
self._logger.warning(f"Cannot parse loc: {loc} not a number.")
loc = 0
scale = self._parse_number(
self.scale_input.text(), float, "Standard deviation", 1.0
)
loc = self._parse_number(self.loc_input.text(), float, "Average value", 0.0)
return Gaussian(self.seed_spin_box.value(), loc, scale)
......@@ -246,26 +227,8 @@ class UniformInput(SignalGeneratorInput):
self.addWidget(self.seed_spin_box, 2, 1)
def get_generator(self) -> SignalGenerator:
low = self.low_input.text().strip()
try:
if not low:
low = -1.0
low = float(low)
except ValueError:
self._logger.warning(f"Cannot parse low: {low} not a number.")
low = -1.0
high = self.high_input.text().strip()
try:
if not high:
high = 1.0
high = float(high)
except ValueError:
self._logger.warning(f"Cannot parse high: {high} not a number.")
high = 1.0
low = self._parse_number(self.low_input.text(), float, "Lower bound", -1.0)
high = self._parse_number(self.high_input.text(), float, "Upper bound", 1.0)
return Uniform(self.seed_spin_box.value(), low, high)
......@@ -284,16 +247,9 @@ class ConstantInput(SignalGeneratorInput):
self.addWidget(self.constant_input, 0, 1)
def get_generator(self) -> SignalGenerator:
constant = self.constant_input.text().strip()
try:
if not constant:
constant = 1.0
constant = complex(constant)
except ValueError:
self._logger.warning(f"Cannot parse constant: {constant} not a number.")
constant = 0.0
constant = self._parse_number(
self.constant_input.text(), complex, "Constant", 1.0
)
return Constant(constant)
......
"""
B-ASIC window to simulate an SFG.
"""
import numpy as np
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from qtpy.QtCore import Qt, Signal
from qtpy.QtGui import QKeySequence
from qtpy.QtWidgets import (
QCheckBox,
QComboBox,
QDialog,
QFileDialog,
QFormLayout,
QFrame,
QGridLayout,
QHBoxLayout,
QLabel,
QLayout,
QLineEdit,
QPushButton,
QShortcut,
QSizePolicy,
QSpinBox,
QVBoxLayout,
)
from b_asic.GUI.signal_generator_input import _GENERATOR_MAPPING
from b_asic.signal_generator import FromFile
class SimulateSFGWindow(QDialog):
......@@ -50,7 +40,7 @@ class SimulateSFGWindow(QDialog):
self.input_grid = QGridLayout()
self.input_files = {}
def add_sfg_to_dialog(self, sfg):
def add_sfg_to_dialog(self, sfg) -> None:
sfg_layout = QVBoxLayout()
options_layout = QFormLayout()
......@@ -112,7 +102,7 @@ class SimulateSFGWindow(QDialog):
self.sfg_to_layout[sfg] = sfg_layout
self.dialog_layout.addLayout(sfg_layout)
def change_input_format(self, i, text):
def change_input_format(self, i: int, text: str) -> None:
grid = self.input_grid.itemAtPosition(i, 2)
if grid:
for j in reversed(range(grid.count())):
......@@ -127,13 +117,11 @@ class SimulateSFGWindow(QDialog):
if text in _GENERATOR_MAPPING:
param_grid = _GENERATOR_MAPPING[text](self._window.logger)
else:
raise Exception("Input selection is not implemented")
raise ValueError("Input selection is not implemented")
self.input_grid.addLayout(param_grid, i, 2)
return
def save_properties(self):
def save_properties(self) -> None:
for sfg, _properties in self.input_fields.items():
ic_value = self.input_fields[sfg]["iteration_count"].value()
if ic_value == 0:
......@@ -148,7 +136,7 @@ class SimulateSFGWindow(QDialog):
if in_format in _GENERATOR_MAPPING:
tmp2 = in_param.get_generator()
else:
raise Exception("Input selection is not implemented")
raise ValueError("Input selection is not implemented")
input_values.append(tmp2)
......@@ -166,45 +154,3 @@ class SimulateSFGWindow(QDialog):
self.accept()
self.simulate.emit()
class Plot(FigureCanvas):
def __init__(
self, simulation, sfg, window, parent=None, width=5, height=4, dpi=100
):
self.simulation = simulation
self.sfg = sfg
self.dpi = dpi
self._window = window
fig = Figure(figsize=(width, height), dpi=dpi)
fig.suptitle(sfg.name, fontsize=20)
self.axes = fig.add_subplot(111)
FigureCanvas.__init__(self, fig)
self.setParent(parent)
FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding, QSizePolicy.Expanding)
FigureCanvas.updateGeometry(self)
self.save_figure = QShortcut(QKeySequence("Ctrl+S"), self)
self.save_figure.activated.connect(self._save_plot_figure)
self._plot_values_sfg()
def _save_plot_figure(self):
self._window.logger.info(f"Saving plot of figure: {self.sfg.name}.")
file_choices = "PNG (*.png)|*.png"
path, ext = QFileDialog.getSaveFileName(self, "Save file", "", file_choices)
path = path.encode("utf-8")
if not path[-4:] == file_choices[-4:].encode("utf-8"):
path += file_choices[-4:].encode("utf-8")
if path:
self.print_figure(path.decode(), dpi=self.dpi)
self._window.logger.info(f"Saved plot: {self.sfg.name} to path: {path}.")
def _plot_values_sfg(self):
x_axis = list(range(len(self.simulation.results["0"])))
for _output in range(self.sfg.output_count):
y_axis = self.simulation.results[str(_output)]
self.axes.plot(x_axis, y_axis)
......@@ -141,9 +141,9 @@ def test_help_dialogs(qtbot):
widget.display_about_page()
widget.display_keybinds_page()
qtbot.wait(100)
widget.faq_page.close()
widget.about_page.close()
widget.keybinds_page.close()
widget._faq_page.close()
widget._about_page.close()
widget._keybinds_page.close()
widget.exit_app()
......@@ -159,7 +159,7 @@ def test_simulate(qtbot, datadir):
qtbot.wait(100)
# widget.dialog.save_properties()
# qtbot.wait(100)
widget.dialog.close()
widget._simulation_dialog.close()
widget.exit_app()
......
......@@ -27,8 +27,8 @@ def test_MemoryVariables(secondorder_iir_schedule):
pc = secondorder_iir_schedule.get_memory_variables()
mem_vars = pc.collection
pattern = re.compile(
"MemoryVariable\\(3, <b_asic.port.OutputPort object at 0x[a-f0-9]+>,"
" {<b_asic.port.InputPort object at 0x[a-f0-9]+>: 4}, 'cmul1.0'\\)"
"MemoryVariable\\(3, <b_asic.port.OutputPort object at 0x[a-fA-F0-9]+>,"
" {<b_asic.port.InputPort object at 0x[a-fA-F0-9]+>: 4}, 'cmul1.0'\\)"
)
mem_var = [m for m in mem_vars if m.name == 'cmul1.0'][0]
assert pattern.match(repr(mem_var))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment