diff --git a/README.md b/README.md index 20f28bee3cfd71fc1b03100b9f9f3632f8a8c284..b2972f30828f19038e5efe612831f989b8f5ffd5 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ The following packages are required in order to build the library: * setuptools * pybind11 * numpy - * pyside2/pyqt5 + * pyside2 To build a binary distribution, the following additional packages are required: * Python: diff --git a/b_asic/GUI/arrow.py b/b_asic/GUI/arrow.py index 14b788a42f2e11b6a780c541f9783d4fefb233a6..c975c728c8c4a36a1aa208e4f5bfd22107bc9a26 100644 --- a/b_asic/GUI/arrow.py +++ b/b_asic/GUI/arrow.py @@ -1,9 +1,9 @@ -from PyQt5.QtWidgets import QApplication, QWidget, QMainWindow, QLabel, QAction,\ +from PySide2.QtWidgets import QApplication, QWidget, QMainWindow, QLabel, QAction,\ QStatusBar, QMenuBar, QLineEdit, QPushButton, QSlider, QScrollArea, QVBoxLayout,\ QHBoxLayout, QDockWidget, QToolBar, QMenu, QLayout, QSizePolicy, QListWidget, QListWidgetItem,\ QGraphicsLineItem, QGraphicsWidget -from PyQt5.QtCore import Qt, QSize, QLineF, QPoint, QRectF -from PyQt5.QtGui import QIcon, QFont, QPainter, QPen +from PySide2.QtCore import Qt, QSize, QLineF, QPoint, QRectF +from PySide2.QtGui import QIcon, QFont, QPainter, QPen from b_asic import Signal diff --git a/b_asic/GUI/drag_button.py b/b_asic/GUI/drag_button.py index a001122e033f9dbe282a19b76ef2850113b9f1c3..5b0a734a09c41cd64848dd994ecaaef2d2fa120f 100644 --- a/b_asic/GUI/drag_button.py +++ b/b_asic/GUI/drag_button.py @@ -5,31 +5,50 @@ This class creates a dragbutton which can be clicked, dragged and dropped. import os.path -from PyQt5.QtWidgets import QPushButton -from PyQt5.QtCore import Qt, QSize -from PyQt5.QtGui import QIcon +from properties_window import PropertiesWindow + +from PySide2.QtWidgets import QPushButton, QMenu, QAction +from PySide2.QtCore import Qt, QSize, Signal +from PySide2.QtGui import QIcon from utils import decorate_class, handle_error @decorate_class(handle_error) class DragButton(QPushButton): - def __init__(self, name, operation, operation_path_name, window, parent = None): + connectionRequested = Signal(QPushButton) + moved = Signal() + def __init__(self, name, operation, operation_path_name, is_show_name, window, parent = None): self.name = name + self.is_show_name = is_show_name self._window = window self.operation = operation self.operation_path_name = operation_path_name self.clicked = 0 self.pressed = False + self._mouse_press_pos = None + self._mouse_move_pos = None super(DragButton, self).__init__(self._window) + def contextMenuEvent(self, event): + menu = QMenu() + properties = QAction("Properties") + menu.addAction(properties) + properties.triggered.connect(self.show_properties_window) + menu.exec_(self.cursor().pos()) + + def show_properties_window(self, event): + self.properties_window = PropertiesWindow(self, self._window) + self.properties_window.show() + + def add_label(self, label): + self.label = label + def mousePressEvent(self, event): - self._mouse_press_pos = None - self._mouse_move_pos = None if event.button() == Qt.LeftButton: - self._mouse_press_pos = event.globalPos() - self._mouse_move_pos = event.globalPos() + self._mouse_press_pos = event.pos() + self._mouse_move_pos = event.pos() for signal in self._window.signalList: signal.update() @@ -43,7 +62,7 @@ class DragButton(QPushButton): color: black }""") path_to_image = os.path.join('operation_icons', self.operation_path_name + '_grey.png') self.setIcon(QIcon(path_to_image)) - self.setIconSize(QSize(50, 50)) + self.setIconSize(QSize(55, 55)) self._window.pressed_operations.append(self) elif self.clicked == 2: @@ -55,27 +74,21 @@ class DragButton(QPushButton): color: black}""") path_to_image = os.path.join('operation_icons', self.operation_path_name + '.png') self.setIcon(QIcon(path_to_image)) - self.setIconSize(QSize(50, 50)) + self.setIconSize(QSize(55, 55)) self._window.pressed_operations.remove(self) super(DragButton, self).mousePressEvent(event) def mouseMoveEvent(self, event): if event.buttons() == Qt.LeftButton: - cur_pos = self.mapToGlobal(self.pos()) - global_pos = event.globalPos() - diff = global_pos - self._mouse_move_pos - new_pos = self.mapFromGlobal(cur_pos + diff) - self.move(new_pos) - - self._mouse_move_pos = global_pos + self.move(self.mapToParent(event.pos() - self._mouse_press_pos)) self._window.update() super(DragButton, self).mouseMoveEvent(event) def mouseReleaseEvent(self, event): if self._mouse_press_pos is not None: - moved = event.globalPos() - self._mouse_press_pos + moved = event.pos() - self._mouse_press_pos if moved.manhattanLength() > 3: event.ignore() return diff --git a/b_asic/GUI/gui_interface.py b/b_asic/GUI/gui_interface.py index a6d9a3b8f436c615201dac8ba7537272573cae00..2d42d808f24d6c55850ae122979f1498f2191a2f 100644 --- a/b_asic/GUI/gui_interface.py +++ b/b_asic/GUI/gui_interface.py @@ -7,7 +7,7 @@ # WARNING! All changes made in this file will be lost! -from PyQt5 import QtCore, QtGui, QtWidgets +from PySide2 import QtCore, QtGui, QtWidgets class Ui_main_window(object): diff --git a/b_asic/GUI/improved_main_window.py b/b_asic/GUI/improved_main_window.py deleted file mode 100644 index fa4e06c4b07f6ea71a2a8b84179fcdc6ad5b1a3b..0000000000000000000000000000000000000000 --- a/b_asic/GUI/improved_main_window.py +++ /dev/null @@ -1,232 +0,0 @@ -"""@package docstring -B-ASIC GUI Module. -This python file is the main window of the GUI for B-ASIC. -""" - -from os import getcwd, path -import sys -import math - -from drag_button import DragButton -from gui_interface import Ui_main_window -from arrow import Arrow -from port_button import PortButton - -from b_asic import Operation -import b_asic.core_operations as c_oper -import b_asic.special_operations as s_oper -from utils import decorate_class, handle_error -from b_asic import SFG -from b_asic import InputPort, OutputPort - -from numpy import linspace - -from PyQt5.QtWidgets import QApplication, QWidget, QMainWindow, QLabel, QAction,\ -QStatusBar, QMenuBar, QLineEdit, QPushButton, QSlider, QScrollArea, QVBoxLayout,\ -QHBoxLayout, QDockWidget, QToolBar, QMenu, QLayout, QSizePolicy, QListWidget,\ -QListWidgetItem, QGraphicsView, QGraphicsScene, QShortcut, QToolTip -from PyQt5.QtCore import Qt, QSize -from PyQt5.QtGui import QIcon, QFont, QPainter, QPen, QBrush, QKeySequence - - -@decorate_class(handle_error) -class MainWindow(QMainWindow): - def __init__(self): - super(MainWindow, self).__init__() - self.ui = Ui_main_window() - self.ui.setupUi(self) - self.setWindowTitle(" ") - self.setWindowIcon(QIcon('small_logo.png')) - self.scene = None - self._operations_from_name = dict() - self.zoom = 1 - self.sfg_name_i = 0 - self.operationList = [] - self.signalList = [] - self.pressed_operations = [] - self.portList = [] - self.pressed_ports = [] - self.sfg_list = [] - self.source = None - self._window = self - - self.init_ui() - self.add_operations_from_namespace(c_oper, self.ui.core_operations_list) - self.add_operations_from_namespace(s_oper, self.ui.special_operations_list) - - self.shortcut_core = QShortcut(QKeySequence("Ctrl+R"), self.ui.operation_box) - self.shortcut_core.activated.connect(self._refresh_operations_list_from_namespace) - - def init_ui(self): - self.ui.core_operations_list.itemClicked.connect(self.on_list_widget_item_clicked) - self.ui.special_operations_list.itemClicked.connect(self.on_list_widget_item_clicked) - self.ui.exit_menu.triggered.connect(self.exit_app) - self.create_toolbar_view() - self.create_graphics_view() - - def create_graphics_view(self): - self.scene = QGraphicsScene() - self.graphic_view = QGraphicsView(self.scene, self) - self.graphic_view.setRenderHint(QPainter.Antialiasing) - self.graphic_view.setGeometry(250, 80, 600, 520) - self.graphic_view.setDragMode(QGraphicsView.ScrollHandDrag) - - def create_toolbar_view(self): - self.toolbar = self.addToolBar("Toolbar") - self.toolbar.addAction("Create SFG", self.create_SFG_from_toolbar) - - def wheelEvent(self, event): - old_zoom = self.zoom - self.zoom += event.angleDelta().y()/2500 - self.graphic_view.scale(self.zoom, self.zoom) - self.zoom = old_zoom - - def exit_app(self, checked): - QApplication.quit() - - def create_SFG_from_toolbar(self): - inputs = [] - outputs = [] - for op in self.pressed_operations: - if isinstance(op.operation, s_oper.Input): - inputs.append(op.operation) - elif isinstance(op.operation, s_oper.Output): - outputs.append(op.operation) - - self.sfg_name_i += 1 - sfg = SFG(inputs=inputs, outputs=outputs, name="sfg" + str(self.sfg_name_i)) - for op in self.pressed_operations: - op.setToolTip(sfg.name) - self.sfg_list.append(sfg) - - - def _determine_port_distance(self, length, ports): - """Determine the distance between each port on the side of an operation. - The method returns the distance that each port should have from 0. - """ - return [length / 2] if ports == 1 else linspace(0, length, ports) - - def _create_port(self, operation, port, output_port=True): - text = ">" if output_port else "<" - button = PortButton(text, operation, port, self) - button.setStyleSheet("background-color: white") - button.connectionRequested.connect(self.connectButton) - return button - - def add_ports(self, operation): - _output_ports_dist = self._determine_port_distance(50 - 15, operation.operation.output_count) - _input_ports_dist = self._determine_port_distance(50 - 15, operation.operation.input_count) - - for i, dist in enumerate(_input_ports_dist): - port = self._create_port(operation, operation.operation.input(i)) - port.move(0, dist) - port.show() - - for i, dist in enumerate(_output_ports_dist): - port = self._create_port(operation, operation.operation.output(i)) - port.move(35, dist) - port.show() - - def get_operations_from_namespace(self, namespace): - return [comp for comp in dir(namespace) if hasattr(getattr(namespace, comp), "type_name")] - - def add_operations_from_namespace(self, namespace, _list): - for attr_name in self.get_operations_from_namespace(namespace): - attr = getattr(namespace, attr_name) - try: - attr.type_name() - item = QListWidgetItem(attr_name) - _list.addItem(item) - self._operations_from_name[attr_name] = attr - except NotImplementedError: - pass - - def _create_operation(self, item): - try: - attr_oper = self._operations_from_name[item.text()]() - attr_button = DragButton(attr_oper.graph_id, attr_oper, attr_oper.type_name().lower(), self) - attr_button.move(250, 100) - attr_button.setFixedSize(50, 50) - attr_button.setStyleSheet("background-color: white; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") - self.add_ports(attr_button) - - icon_path = path.join("operation_icons", f"{attr_oper.type_name().lower()}.png") - if not path.exists(icon_path): - icon_path = path.join("operation_icons", f"custom_operation.png") - attr_button.setIcon(QIcon(icon_path)) - attr_button.setIconSize(QSize(50, 50)) - - - attr_button.setToolTip("No sfg") - attr_button.setStyleSheet(""" QToolTip { background-color: white; - color: black }""") - - attr_button.setParent(None) - self.scene.addWidget(attr_button) - self.operationList.append(attr_button) - except Exception as e: - print("Unexpected error occured: ", e) - - def _refresh_operations_list_from_namespace(self): - self.ui.core_operations_list.clear() - self.ui.special_operations_list.clear() - - self.add_operations_from_namespace(c_oper, self.ui.core_operations_list) - self.add_operations_from_namespace(s_oper, self.ui.special_operations_list) - - def print_input_port_1(self): - print("Input port 1") - - def print_input_port_2(self): - print("Input port 2") - - def print_output_port_1(self): - print("Output port 1") - - def print_output_port_2(self): - print("Output port 2") - - def on_list_widget_item_clicked(self, item): - self._create_operation(item) - - def keyPressEvent(self, event): - pressed_buttons = [] - for op in self.operationList: - if op.pressed: - pressed_buttons.append(op) - if event.key() == Qt.Key_Delete: - for pressed_op in pressed_buttons: - self.operationList.remove(pressed_op) - pressed_op.remove() - super().keyPressEvent(event) - - def connectButton(self, button): - if len(self.pressed_ports) < 2: - return - for i in range(len(self.pressed_ports) - 1): - if isinstance(self.pressed_ports[i].port, OutputPort) and \ - isinstance(self.pressed_ports[i+1].port, InputPort): - line = Arrow(self.pressed_ports[i], self.pressed_ports[i + 1], self) - self.scene.addItem(line) - self.signalList.append(line) - - self.deselectPorts() - self.update() - - def paintEvent(self, event): - for signal in self.signalList: - signal.moveLine() - - def deselectPorts(self): - for port in self.pressed_ports: - port.setStyleSheet("background-color: white") - port.pressed = False - port.clicked = 0 - self.pressed_ports.clear() - -if __name__ == "__main__": - app = QApplication(sys.argv) - window = MainWindow() - window.show() - sys.exit(app.exec_()) diff --git a/b_asic/GUI/main_window.py b/b_asic/GUI/main_window.py index b65a118438b64ee9b6540bac5177b77cc9c2cf44..ffa8448a7b2ed93b2b7cb79d3c609ac2fed34303 100644 --- a/b_asic/GUI/main_window.py +++ b/b_asic/GUI/main_window.py @@ -1,414 +1,244 @@ """@package docstring B-ASIC GUI Module. -This python file is an example of how a GUI can be implemented -using buttons and textboxes. +This python file is the main window of the GUI for B-ASIC. """ +from os import getcwd, path import sys -from PyQt5.QtWidgets import QApplication, QWidget, QMainWindow, QLabel, QAction,\ +from drag_button import DragButton +from gui_interface import Ui_main_window +from arrow import Arrow +from port_button import PortButton + +from b_asic import Operation +import b_asic.core_operations as c_oper +import b_asic.special_operations as s_oper +from utils import decorate_class, handle_error +from b_asic import SFG +from b_asic import InputPort, OutputPort + +from numpy import linspace + +from PySide2.QtWidgets import QApplication, QWidget, QMainWindow, QLabel, QAction,\ QStatusBar, QMenuBar, QLineEdit, QPushButton, QSlider, QScrollArea, QVBoxLayout,\ -QHBoxLayout, QDockWidget, QToolBar, QMenu -from PyQt5.QtCore import Qt, QSize, pyqtSlot -from PyQt5.QtGui import QIcon, QFont, QPainter, QPen, QColor - -from b_asic.core_operations import Addition - - -class DragButton(QPushButton): - def __init__(self, name, window, parent = None): - self.name = name - self.__menu = None - self.__window = window - self.counter = 0 - self.clicked = 0 - self.pressed = False - print("Constructor" + self.name) - super(DragButton, self).__init__(self.__window) - - self.__window.setContextMenuPolicy(Qt.CustomContextMenu) - self.__window.customContextMenuRequested.connect(self.create_menu) - - - @pyqtSlot(QAction) - def actionClicked(self, action): - print("Triggern "+ self.name, self.__menu.name) - #self.__window.check_for_remove_op(self.name) - - #def show_context_menu(self, point): - # show context menu - - - def create_menu(self, point): - self.counter += 1 - # create context menu - popMenu = MyMenu('Menu' + str(self.counter)) - popMenu.addAction(QAction('Add a signal', self)) - popMenu.addAction(QAction('Remove a signal', self)) - #action = QAction('Remove operation', self) - popMenu.addAction('Remove operation', lambda:self.removeAction(self)) - popMenu.addSeparator() - popMenu.addAction(QAction('Remove all signals', self)) - self.__window.menuList.append(popMenu) - #self.__window.actionList.append(action) - self.__menu = popMenu - self.pressed = False - self.__menu.exec_(self.__window.sender().mapToGlobal(point)) - self.__menu.triggered.connect(self.actionClicked) - - - def removeAction(self, op): - print(op.__menu.name, op.name) - op.remove() - - """This class is made to create a draggable button""" - - def mousePressEvent(self, event): - self._mouse_press_pos = None - self._mouse_move_pos = None - self.clicked += 1 - if self.clicked == 1: - self.pressed = True - self.setStyleSheet("background-color: grey; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") - elif self.clicked == 2: - self.clicked = 0 - self.presseed = False - self.setStyleSheet("background-color: white; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") - - if event.button() == Qt.LeftButton: - self._mouse_press_pos = event.globalPos() - self._mouse_move_pos = event.globalPos() - - super(DragButton, self).mousePressEvent(event) - - def mouseMoveEvent(self, event): - if event.buttons() == Qt.LeftButton: - cur_pos = self.mapToGlobal(self.pos()) - global_pos = event.globalPos() - diff = global_pos - self._mouse_move_pos - new_pos = self.mapFromGlobal(cur_pos + diff) - self.move(new_pos) - self.pressed = False - - self._mouse_move_pos = global_pos - - super(DragButton, self).mouseMoveEvent(event) - - def mouseReleaseEvent(self, event): - if self._mouse_press_pos is not None: - moved = event.globalPos() - self._mouse_press_pos - if moved.manhattanLength() > 3: - event.ignore() - return - - super(DragButton, self).mouseReleaseEvent(event) - - def remove(self): - self.deleteLater() - -class SubWindow(QWidget): - """Creates a sub window """ - def create_window(self, window_width, window_height): - """Creates a window - """ - parent = None - super(SubWindow, self).__init__(parent) - self.setWindowFlags(Qt.WindowStaysOnTopHint) - self.resize(window_width, window_height) +QHBoxLayout, QDockWidget, QToolBar, QMenu, QLayout, QSizePolicy, QListWidget,\ +QListWidgetItem, QGraphicsView, QGraphicsScene, QShortcut, QGraphicsTextItem,\ +QGraphicsProxyWidget +from PySide2.QtCore import Qt, QSize +from PySide2.QtGui import QIcon, QFont, QPainter, QPen, QBrush, QKeySequence -class MyMenu(QMenu): - def __init__(self, name, parent = None): - self.name = name - super(MyMenu, self).__init__() - +MIN_WIDTH_SCENE = 600 +MIN_HEIGHT_SCENE = 520 +@decorate_class(handle_error) class MainWindow(QMainWindow): - """Main window for the program""" - # pylint: disable=too-many-instance-attributes - # Eight is reasonable in this case. - def __init__(self, *args, **kwargs): - super(MainWindow, self).__init__(*args, **kwargs) - self.init_ui() - self.counter = 0 - self.operations = [] - self.menuList = [] - self.actionList = [] - - def init_ui(self): + def __init__(self): + super(MainWindow, self).__init__() + self.ui = Ui_main_window() + self.ui.setupUi(self) self.setWindowTitle(" ") self.setWindowIcon(QIcon('small_logo.png')) - self.create_operation_menu() - self.create_menu_bar() - self.setStatusBar(QStatusBar(self)) - - def create_operation_menu(self): - self.operation_box = QDockWidget("Operation Box", self) - self.operation_box.setAllowedAreas(Qt.LeftDockWidgetArea) - self.test = QToolBar(self) - self.operation_list = QMenuBar(self) - self.test.addWidget(self.operation_list) - self.test.setOrientation(Qt.Vertical) - self.operation_list.setStyleSheet("background-color:rgb(222,222,222); vertical") - basic_operations = self.operation_list.addMenu('Basic operations') - special_operations = self.operation_list.addMenu('Special operations') - - self.addition_menu_item = QAction('&Addition', self) - self.addition_menu_item.setStatusTip("Add addition operation to workspace") - self.addition_menu_item.triggered.connect(self.create_addition_operation) - basic_operations.addAction(self.addition_menu_item) - - self.subtraction_menu_item = QAction('&Subtraction', self) - self.subtraction_menu_item.setStatusTip("Add subtraction operation to workspace") - self.subtraction_menu_item.triggered.connect(self.create_subtraction_operation) - basic_operations.addAction(self.subtraction_menu_item) - - self.multiplication_menu_item = QAction('&Multiplication', self) - self.multiplication_menu_item.setStatusTip("Add multiplication operation to workspace") - self.multiplication_menu_item.triggered.connect(self.create_multiplication_operation) - basic_operations.addAction(self.multiplication_menu_item) - - self.division_menu_item = QAction('&Division', self) - self.division_menu_item.setStatusTip("Add division operation to workspace") - #self.division_menu_item.triggered.connect(self.create_division_operation) - basic_operations.addAction(self.division_menu_item) - - self.constant_menu_item = QAction('&Constant', self) - self.constant_menu_item.setStatusTip("Add constant operation to workspace") - #self.constant_menu_item.triggered.connect(self.create_constant_operation) - basic_operations.addAction(self.constant_menu_item) - - self.square_root_menu_item = QAction('&Square root', self) - self.square_root_menu_item.setStatusTip("Add square root operation to workspace") - #self.square_root_menu_item.triggered.connect(self.create_square_root_operation) - basic_operations.addAction(self.square_root_menu_item) - - self.complex_conjugate_menu_item = QAction('&Complex conjugate', self) - self.complex_conjugate_menu_item.setStatusTip("Add complex conjugate operation to workspace") - #self.complex_conjugate_menu_item.triggered.connect(self.create_complex_conjugate_operation) - basic_operations.addAction(self.complex_conjugate_menu_item) - - self.max_menu_item = QAction('&Max', self) - self.max_menu_item.setStatusTip("Add max operation to workspace") - #self.max_menu_item.triggered.connect(self.create_max_operation) - basic_operations.addAction(self.max_menu_item) - - self.min_menu_item = QAction('&Min', self) - self.min_menu_item.setStatusTip("Add min operation to workspace") - #self.min_menu_item.triggered.connect(self.create_min_operation) - basic_operations.addAction(self.min_menu_item) - - self.absolute_menu_item = QAction('&Absolute', self) - self.absolute_menu_item.setStatusTip("Add absolute operation to workspace") - #self.absolute_menu_item.triggered.connect(self.create_absolute_operation) - basic_operations.addAction(self.absolute_menu_item) - - self.constant_addition_menu_item = QAction('&Constant addition', self) - self.constant_addition_menu_item.setStatusTip("Add constant addition operation to workspace") - #self.constant_addition_menu_item.triggered.connect(self.create_constant_addition_operation) - basic_operations.addAction(self.constant_addition_menu_item) - - self.constant_subtraction_menu_item = QAction('&Constant subtraction', self) - self.constant_subtraction_menu_item.setStatusTip("Add constant subtraction operation to workspace") - #self.constant_subtraction_menu_item.triggered.connect(self.create_constant_subtraction_operation) - basic_operations.addAction(self.constant_subtraction_menu_item) - - self.constant_multiplication_menu_item = QAction('&Constant multiplication', self) - self.constant_multiplication_menu_item.setStatusTip("Add constant multiplication operation to workspace") - #self.constant_multiplication_menu_item.triggered.connect(self.create_constant_multiplication_operation) - basic_operations.addAction(self.constant_multiplication_menu_item) - - self.constant_division_menu_item = QAction('&Constant division', self) - self.constant_division_menu_item.setStatusTip("Add constant division operation to workspace") - #self.constant_division_menu_item.triggered.connect(self.create_constant_division_operation) - basic_operations.addAction(self.constant_division_menu_item) - - self.butterfly_menu_item = QAction('&Butterfly', self) - self.butterfly_menu_item.setStatusTip("Add butterfly operation to workspace") - #self.butterfly_menu_item.triggered.connect(self.create_butterfly_operation) - basic_operations.addAction(self.butterfly_menu_item) - - self.operation_box.setWidget(self.operation_list) - self.operation_box.setMaximumSize(240, 400) - self.operation_box.setFeatures(QDockWidget.NoDockWidgetFeatures) - self.operation_box.setFixedSize(300, 500) - self.operation_box.setStyleSheet("background-color: white; border-style: solid;\ - border-color: black; border-width: 2px") - self.addDockWidget(Qt.LeftDockWidgetArea, self.operation_box) - - def create_addition_operation(self): - self.counter += 1 - - # Create drag button - addition_operation = DragButton("OP" + str(self.counter), self) - addition_operation.move(250, 100) - addition_operation.setFixedSize(50, 50) - addition_operation.setStyleSheet("background-color: white; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") - addition_operation.clicked.connect(self.create_sub_window) - #self.addition_operation.setIcon(QIcon("GUI'\'operation_icons'\'plus.png")) - addition_operation.setText("OP" + str(self.counter)) - addition_operation.setIconSize(QSize(50, 50)) - addition_operation.show() - self.operations.append(addition_operation) - - # set context menu policies - #self.addition_operation.setContextMenuPolicy(Qt.CustomContextMenu) - #self.addition_operation.customContextMenuRequested.connect(self.show_context_menu) - - #self.action.triggered.connect(lambda checked: self.remove(self.addition_operation.name)) - - def check_for_remove_op(self, name): - self.remove(name) - - - def remove(self, name): - for op in self.operations: - print(name, op.name) - if op.name == name: - self.operations.remove(op) - op.remove() - - def create_subtraction_operation(self): - self.subtraction_operation = DragButton("sub" + str(self.counter), self) - self.subtraction_operation.move(250, 100) - self.subtraction_operation.setFixedSize(50, 50) - self.subtraction_operation.setStyleSheet("background-color: white; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") - self.subtraction_operation.setIcon(QIcon("GUI'\'operation_icons'\'minus.png")) - self.subtraction_operation.setIconSize(QSize(50, 50)) - self.subtraction_operation.clicked.connect(self.create_sub_window) - self.subtraction_operation.show() - - # set context menu policies - self.subtraction_operation.setContextMenuPolicy(Qt.CustomContextMenu) - self.subtraction_operation.customContextMenuRequested.connect(self.show_context_menu) - - # create context menu - self.button_context_menu = QMenu(self) - self.button_context_menu.addAction(QAction('Add a signal', self)) - self.button_context_menu.addAction(QAction('Remove a signal', self)) - self.button_context_menu.addSeparator() - self.button_context_menu.addAction(QAction('Remove all signals', self)) - - def create_multiplication_operation(self): - self.multiplication_operation = DragButton(self) - self.multiplication_operation.move(250, 100) - self.multiplication_operation.setFixedSize(50, 50) - self.multiplication_operation.setStyleSheet("background-color: white; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") - self.multiplication_operation.clicked.connect(self.create_sub_window) - self.multiplication_operation.setIcon(QIcon(r"GUI\operation_icons\plus.png")) - self.multiplication_operation.setIconSize(QSize(50, 50)) - self.multiplication_operation.show() - - # set context menu policies - self.multiplication_operation.setContextMenuPolicy(Qt.CustomContextMenu) - self.multiplication_operation.customContextMenuRequested.connect(self.show_context_menu) - - # create context menu - self.button_context_menu = QMenu(self) - self.button_context_menu.addAction(QAction('Add a signal', self)) - self.button_context_menu.addAction(QAction('Remove a signal', self)) - self.button_context_menu.addSeparator() - self.button_context_menu.addAction(QAction('Remove all signals', self)) - - - def create_menu_bar(self): - # Menu buttons - load_button = QAction("Load", self) - save_button = QAction("Save", self) - - exit_button = QAction("Exit", self) - exit_button.setShortcut("Ctrl+Q") - exit_button.triggered.connect(self.exit_app) - - edit_button = QAction("Edit", self) - edit_button.setStatusTip("Open edit menu") - edit_button.triggered.connect(self.on_edit_button_click) - - view_button = QAction("View", self) - view_button.setStatusTip("Open view menu") - view_button.triggered.connect(self.on_view_button_click) - - menu_bar = QMenuBar() - menu_bar.setStyleSheet("background-color:rgb(222, 222, 222)") - self.setMenuBar(menu_bar) - - file_menu = menu_bar.addMenu("&File") - file_menu.addAction(save_button) - file_menu.addSeparator() - file_menu.addAction(exit_button) - - edit_menu = menu_bar.addMenu("&Edit") - edit_menu.addAction(edit_button) - - edit_menu.addSeparator() - - view_menu = menu_bar.addMenu("&View") - view_menu.addAction(view_button) - - - def create_sub_window(self): - """ Example of how to create a sub window - """ - self.sub_window = SubWindow() - self.sub_window.create_window(400, 300) - self.sub_window.setWindowTitle("Properties") - - self.sub_window.properties_label = QLabel(self.sub_window) - self.sub_window.properties_label.setText('Properties') - self.sub_window.properties_label.setFixedWidth(400) - self.sub_window.properties_label.setFont(QFont('SansSerif', 14, QFont.Bold)) - self.sub_window.properties_label.setAlignment(Qt.AlignCenter) + self.scene = None + self._operations_from_name = dict() + self.zoom = 1 + self.sfg_name_i = 0 + self.operationList = [] + self.signalList = [] + self.pressed_operations = [] + self.portList = [] + self.pressed_ports = [] + self.sfg_list = [] + self.source = None + self._window = self - self.sub_window.name_label = QLabel(self.sub_window) - self.sub_window.name_label.setText('Name:') - self.sub_window.name_label.move(20, 40) - - self.sub_window.name_line = QLineEdit(self.sub_window) - self.sub_window.name_line.setPlaceholderText("Write a name here") - self.sub_window.name_line.move(70, 40) - self.sub_window.name_line.resize(100, 20) - - self.sub_window.id_label = QLabel(self.sub_window) - self.sub_window.id_label.setText('Id:') - self.sub_window.id_label.move(20, 70) - - self.sub_window.id_line = QLineEdit(self.sub_window) - self.sub_window.id_line.setPlaceholderText("Write an id here") - self.sub_window.id_line.move(70, 70) - self.sub_window.id_line.resize(100, 20) + self.init_ui() + self.add_operations_from_namespace(c_oper, self.ui.core_operations_list) + self.add_operations_from_namespace(s_oper, self.ui.special_operations_list) - self.sub_window.show() + self.shortcut_core = QShortcut(QKeySequence("Ctrl+R"), self.ui.operation_box) + self.shortcut_core.activated.connect(self._refresh_operations_list_from_namespace) - def keyPressEvent(self, event): - for op in self.operations: - if event.key() == Qt.Key_Delete and op.pressed: - self.operations.remove(op) - op.remove() - - def on_file_button_click(self): - print("File") + self.move_button_index = 0 + self.is_show_names = True - def on_edit_button_click(self): - print("Edit") + self.check_show_names = QAction("Show operation names") + self.check_show_names.triggered.connect(self.view_operation_names) + self.check_show_names.setCheckable(True) + self.check_show_names.setChecked(1) + self.ui.view_menu.addAction(self.check_show_names) - def on_view_button_click(self): - print("View") + def init_ui(self): + self.ui.core_operations_list.itemClicked.connect(self.on_list_widget_item_clicked) + self.ui.special_operations_list.itemClicked.connect(self.on_list_widget_item_clicked) + self.ui.exit_menu.triggered.connect(self.exit_app) + self.create_toolbar_view() + self.create_graphics_view() + + def create_graphics_view(self): + self.scene = QGraphicsScene(self) + self.graphic_view = QGraphicsView(self.scene, self) + self.graphic_view.setRenderHint(QPainter.Antialiasing) + self.graphic_view.setGeometry(self.ui.operation_box.width(), 0, self.width(), self.height()) + self.graphic_view.setDragMode(QGraphicsView.ScrollHandDrag) + + def create_toolbar_view(self): + self.toolbar = self.addToolBar("Toolbar") + self.toolbar.addAction("Create SFG", self.create_SFG_from_toolbar) + + def resizeEvent(self, event): + self.ui.operation_box.setGeometry(10, 10, self.ui.operation_box.width(), self.height()) + self.graphic_view.setGeometry(self.ui.operation_box.width() + 20, 0, self.width() - self.ui.operation_box.width() - 20, self.height()) + super(MainWindow, self).resizeEvent(event) + + def wheelEvent(self, event): + if event.modifiers() == Qt.ControlModifier: + old_zoom = self.zoom + self.zoom += event.angleDelta().y()/2500 + self.graphic_view.scale(self.zoom, self.zoom) + self.zoom = old_zoom + + def view_operation_names(self, event): + if self.check_show_names.isChecked(): + self.is_show_names = True + else: + self.is_show_names = False + for operation in self.operationList: + operation.label.setOpacity(self.is_show_names) + operation.is_show_name = self.is_show_names def exit_app(self, checked): QApplication.quit() - def clicked(self): - print("Drag button clicked") - + def create_SFG_from_toolbar(self): + inputs = [] + outputs = [] + for op in self.pressed_operations: + if isinstance(op.operation, s_oper.Input): + inputs.append(op.operation) + elif isinstance(op.operation, s_oper.Output): + outputs.append(op.operation) + + self.sfg_name_i += 1 + sfg = SFG(inputs=inputs, outputs=outputs, name="sfg" + str(self.sfg_name_i)) + for op in self.pressed_operations: + op.setToolTip(sfg.name) + self.sfg_list.append(sfg) + + def _determine_port_distance(self, length, ports): + """Determine the distance between each port on the side of an operation. + The method returns the distance that each port should have from 0. + """ + return [length / 2] if ports == 1 else linspace(0, length, ports) + + def _create_port(self, operation, port, output_port=True): + text = ">" if output_port else "<" + button = PortButton(text, operation, port, self) + button.setStyleSheet("background-color: white") + button.connectionRequested.connect(self.connectButton) + return button + + def add_ports(self, operation): + _output_ports_dist = self._determine_port_distance(55 - 17, operation.operation.output_count) + _input_ports_dist = self._determine_port_distance(55 - 17, operation.operation.input_count) + + for i, dist in enumerate(_input_ports_dist): + port = self._create_port(operation, operation.operation.input(i)) + port.move(0, dist) + port.show() + + for i, dist in enumerate(_output_ports_dist): + port = self._create_port(operation, operation.operation.output(i)) + port.move(55 - 12, dist) + port.show() + + def get_operations_from_namespace(self, namespace): + return [comp for comp in dir(namespace) if hasattr(getattr(namespace, comp), "type_name")] + + def add_operations_from_namespace(self, namespace, _list): + for attr_name in self.get_operations_from_namespace(namespace): + attr = getattr(namespace, attr_name) + try: + attr.type_name() + item = QListWidgetItem(attr_name) + _list.addItem(item) + self._operations_from_name[attr_name] = attr + except NotImplementedError: + pass + + def _create_operation(self, item): + try: + attr_oper = self._operations_from_name[item.text()]() + attr_button = DragButton(attr_oper.graph_id, attr_oper, attr_oper.type_name().lower(), True, self) + attr_button.move(250, 100) + attr_button.setFixedSize(55, 55) + attr_button.setStyleSheet("background-color: white; border-style: solid;\ + border-color: black; border-width: 2px") + self.add_ports(attr_button) + + icon_path = path.join("operation_icons", f"{attr_oper.type_name().lower()}.png") + if not path.exists(icon_path): + icon_path = path.join("operation_icons", f"custom_operation.png") + attr_button.setIcon(QIcon(icon_path)) + attr_button.setIconSize(QSize(55, 55)) + attr_button.setToolTip("No sfg") + attr_button.setStyleSheet(""" QToolTip { background-color: white; + color: black }""") + attr_button.setParent(None) + attr_button_scene = self.scene.addWidget(attr_button) + attr_button_scene.moveBy(self.move_button_index * 100, 0) + self.move_button_index += 1 + operation_label = QGraphicsTextItem(attr_oper.type_name(), attr_button_scene) + if not self.is_show_names: + operation_label.setOpacity(0) + operation_label.setTransformOriginPoint(operation_label.boundingRect().center()) + operation_label.moveBy(10, -20) + attr_button.add_label(operation_label) + self.operationList.append(attr_button) + except Exception as e: + print("Unexpected error occured: ", e) + + def _refresh_operations_list_from_namespace(self): + self.ui.core_operations_list.clear() + self.ui.special_operations_list.clear() + + self.add_operations_from_namespace(c_oper, self.ui.core_operations_list) + self.add_operations_from_namespace(s_oper, self.ui.special_operations_list) + + def on_list_widget_item_clicked(self, item): + self._create_operation(item) + + def keyPressEvent(self, event): + pressed_operations = [] + for op in self.operationList: + if op.pressed: + pressed_operations.append(op) + if event.key() == Qt.Key_Delete: + for pressed_op in pressed_operations: + self.operationList.remove(pressed_op) + pressed_op.remove() + self.move_button_index -= 1 + super().keyPressEvent(event) + + def connectButton(self, button): + if len(self.pressed_ports) < 2: + return + for i in range(len(self.pressed_ports) - 1): + if isinstance(self.pressed_ports[i].port, OutputPort) and \ + isinstance(self.pressed_ports[i+1].port, InputPort): + line = Arrow(self.pressed_ports[i], self.pressed_ports[i + 1], self) + self.scene.addItem(line) + self.signalList.append(line) + + self.update() + + def paintEvent(self, event): + for signal in self.signalList: + signal.moveLine() if __name__ == "__main__": app = QApplication(sys.argv) window = MainWindow() - window.resize(960, 720) window.show() - app.exec_() + sys.exit(app.exec_()) diff --git a/b_asic/GUI/port_button.py b/b_asic/GUI/port_button.py index 3dc55764c630a2843618d21bff1f875e516c48d2..04b355663cd057b399c6e2c182f70aebeddffc50 100644 --- a/b_asic/GUI/port_button.py +++ b/b_asic/GUI/port_button.py @@ -1,12 +1,12 @@ import sys -from PyQt5.QtWidgets import QPushButton, QMenu -from PyQt5.QtCore import Qt, pyqtSignal +from PySide2.QtWidgets import QPushButton, QMenu +from PySide2.QtCore import Qt, Signal class PortButton(QPushButton): - connectionRequested = pyqtSignal(QPushButton) - moved = pyqtSignal() + connectionRequested = Signal(QPushButton) + moved = Signal() def __init__(self, name, operation, port, window, parent=None): self.pressed = False self.window = window @@ -21,7 +21,6 @@ class PortButton(QPushButton): menu.exec_(self.cursor().pos()) def mousePressEvent(self, event): - if event.button() == Qt.LeftButton: self.clicked += 1 if self.clicked == 1: @@ -30,11 +29,11 @@ class PortButton(QPushButton): self.window.pressed_ports.append(self) elif self.clicked == 2: self.setStyleSheet("background-color: white") - self.pressed = False + self.pressed = False self.clicked = 0 self.window.pressed_ports.remove(self) super(PortButton, self).mousePressEvent(event) - + def mouseReleaseEvent(self, event): super(PortButton, self).mouseReleaseEvent(event) diff --git a/b_asic/GUI/properties_window.py b/b_asic/GUI/properties_window.py new file mode 100644 index 0000000000000000000000000000000000000000..bc9d566c36505668a3fd53d6745f298f7050c44b --- /dev/null +++ b/b_asic/GUI/properties_window.py @@ -0,0 +1,62 @@ +from PySide2.QtWidgets import QDialog, QLineEdit, QPushButton, QVBoxLayout, QHBoxLayout,\ +QLabel, QCheckBox +from PySide2.QtCore import Qt +from PySide2.QtGui import QIntValidator + +class PropertiesWindow(QDialog): + def __init__(self, operation, main_window): + super(PropertiesWindow, self).__init__() + self.operation = operation + self.main_window = main_window + self.setWindowFlags(Qt.WindowTitleHint | Qt.WindowCloseButtonHint) + self.setWindowTitle("Properties") + + self.name_layout = QHBoxLayout() + self.name_layout.setSpacing(50) + self.name_label = QLabel("Name:") + self.edit_name = QLineEdit(self.operation.operation_path_name) + self.name_layout.addWidget(self.name_label) + self.name_layout.addWidget(self.edit_name) + + self.vertical_layout = QVBoxLayout() + self.vertical_layout.addLayout(self.name_layout) + + if self.operation.operation_path_name == "c": + self.constant_layout = QHBoxLayout() + self.constant_layout.setSpacing(50) + self.constant_value = QLabel("Constant:") + self.edit_constant = QLineEdit(str(self.operation.operation.value)) + self.only_accept_int = QIntValidator() + self.edit_constant.setValidator(self.only_accept_int) + self.constant_layout.addWidget(self.constant_value) + self.constant_layout.addWidget(self.edit_constant) + self.vertical_layout.addLayout(self.constant_layout) + + self.show_name_layout = QHBoxLayout() + self.check_show_name = QCheckBox("Show name?") + if self.operation.is_show_name: + self.check_show_name.setChecked(1) + else: + self.check_show_name.setChecked(0) + self.check_show_name.setLayoutDirection(Qt.RightToLeft) + self.check_show_name.setStyleSheet("spacing: 170px") + self.show_name_layout.addWidget(self.check_show_name) + self.vertical_layout.addLayout(self.show_name_layout) + + self.ok = QPushButton("OK") + self.ok.clicked.connect(self.save_properties) + self.vertical_layout.addWidget(self.ok) + self.setLayout(self.vertical_layout) + + def save_properties(self): + self.operation.name = self.edit_name.text() + self.operation.label.setPlainText(self.operation.name) + if self.operation.operation_path_name == "c": + self.operation.operation.value = self.edit_constant.text() + if self.check_show_name.isChecked(): + self.operation.label.setOpacity(1) + self.operation.is_show_name = True + else: + self.operation.label.setOpacity(0) + self.operation.is_show_name = False + self.reject() \ No newline at end of file diff --git a/b_asic/GUI/utils.py b/b_asic/GUI/utils.py index 721496c7db7d0259d6335681f6a4c95c6c5930b6..4fba57ed96cb0073511125d341c2ee2ede4ad182 100644 --- a/b_asic/GUI/utils.py +++ b/b_asic/GUI/utils.py @@ -1,4 +1,4 @@ -from PyQt5.QtWidgets import QErrorMessage +from PySide2.QtWidgets import QErrorMessage from traceback import format_exc def handle_error(fn): diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 3e6cd78755f727d034723e2f02e707225cbe9611..ec7306c6f4c97b5c0377794e48524d09c7ed159b 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -240,3 +240,18 @@ class Butterfly(AbstractOperation): def evaluate(self, a, b): return a + b, a - b + +class MAD(AbstractOperation): + """Multiply-and-add operation. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, src2: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 3, output_count = 1, name = name, input_sources = [src0, src1, src2]) + + @property + def type_name(self) -> TypeName: + return "mad" + + def evaluate(self, a, b, c): + return a * b + c diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py index 5efa8038e6c914d0e8f2023ebf6f05eb58664e5b..e08422a842a84d08dcab58ab03d7f581cb1bc664 100644 --- a/b_asic/graph_component.py +++ b/b_asic/graph_component.py @@ -105,6 +105,10 @@ class AbstractGraphComponent(GraphComponent): self._graph_id = "" self._parameters = {} + def __str__(self): + return f"id: {self.graph_id if self.graph_id else 'no_id'}, \tname: {self.name if self.name else 'no_name'}" + \ + "".join((f", \t{key}: {str(param)}" for key, param in self._parameters.items())) + @property def name(self) -> Name: return self._name diff --git a/b_asic/operation.py b/b_asic/operation.py index f8ac22e2a1d26e13365d0d742775de6f1f020057..02ba1aa50682448e931a0694d24e03e20eadd399 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -186,6 +186,13 @@ class Operation(GraphComponent, SignalSourceProvider): """Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index.""" raise NotImplementedError + @abstractmethod + def to_sfg(self) -> "SFG": + """Convert the operation into its corresponding SFG. + If the operation is composed by multiple operations, the operation will be split. + """ + raise NotImplementedError + class AbstractOperation(Operation, AbstractGraphComponent): """Generic abstract operation class which most implementations will derive from. @@ -256,6 +263,47 @@ class AbstractOperation(Operation, AbstractGraphComponent): from b_asic.core_operations import Constant, Division return Division(Constant(src) if isinstance(src, Number) else src, self) + def __str__(self): + inputs_dict = dict() + for i, port in enumerate(self.inputs): + if port.signal_count == 0: + inputs_dict[i] = '-' + break + dict_ele = [] + for signal in port.signals: + if signal.source: + if signal.source.operation.graph_id: + dict_ele.append(signal.source.operation.graph_id) + else: + dict_ele.append("no_id") + else: + if signal.graph_id: + dict_ele.append(signal.graph_id) + else: + dict_ele.append("no_id") + inputs_dict[i] = dict_ele + + outputs_dict = dict() + for i, port in enumerate(self.outputs): + if port.signal_count == 0: + outputs_dict[i] = '-' + break + dict_ele = [] + for signal in port.signals: + if signal.destination: + if signal.destination.operation.graph_id: + dict_ele.append(signal.destination.operation.graph_id) + else: + dict_ele.append("no_id") + else: + if signal.graph_id: + dict_ele.append(signal.graph_id) + else: + dict_ele.append("no_id") + outputs_dict[i] = dict_ele + + return super().__str__() + f", \tinputs: {str(inputs_dict)}, \toutputs: {str(outputs_dict)}" + @property def input_count(self) -> int: return len(self._input_ports) @@ -361,6 +409,30 @@ class AbstractOperation(Operation, AbstractGraphComponent): pass return [self] + def to_sfg(self) -> "SFG": + # Import here to avoid circular imports. + from b_asic.special_operations import Input, Output + from b_asic.signal_flow_graph import SFG + + inputs = [Input() for i in range(self.input_count)] + + try: + last_operations = self.evaluate(*inputs) + if isinstance(last_operations, Operation): + last_operations = [last_operations] + outputs = [Output(o) for o in last_operations] + except TypeError: + operation_copy = self.copy_component() + inputs = [] + for i in range(self.input_count): + _input = Input() + operation_copy.input(i).connect(_input) + inputs.append(_input) + + outputs = [Output(operation_copy)] + + return SFG(inputs=inputs, outputs=outputs) + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: if output_index < 0 or output_index >= self.output_count: raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") @@ -370,6 +442,16 @@ class AbstractOperation(Operation, AbstractGraphComponent): def neighbors(self) -> Iterable[GraphComponent]: return list(self.input_signals) + list(self.output_signals) + @property + def preceding_operations(self) -> Iterable[Operation]: + """Returns an Iterable of all Operations that are connected to this Operations input ports.""" + return [signal.source.operation for signal in self.input_signals if signal.source] + + @property + def subsequent_operations(self) -> Iterable[Operation]: + """Returns an Iterable of all Operations that are connected to this Operations output ports.""" + return [signal.destination.operation for signal in self.output_signals if signal.destination] + @property def source(self) -> OutputPort: if self.output_count != 1: diff --git a/b_asic/port.py b/b_asic/port.py index 59a218d9f8aa288d0aacb9dea15ca2cf0a604355..20783d5df0962b034aee2b6e934255a9fc9cd6e6 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -128,7 +128,7 @@ class InputPort(AbstractPort): signal.set_destination(self) def remove_signal(self, signal: Signal) -> None: - assert signal is self._source_signal, "Attempted to remove already removed signal." + assert signal is self._source_signal, "Attempted to remove signal that is not connected." self._source_signal = None signal.remove_destination() @@ -177,7 +177,7 @@ class OutputPort(AbstractPort, SignalSourceProvider): signal.set_source(self) def remove_signal(self, signal: Signal) -> None: - assert signal in self._destination_signals, "Attempted to remove already removed signal." + assert signal in self._destination_signals, "Attempted to remove signal that is not connected." self._destination_signals.remove(signal) signal.remove_source() diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 6483cfc1476047bcbe897b871cc179990b894c4d..d51f13b4209fd1d5fc87e369b2f23dc8bf69301b 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -6,13 +6,16 @@ TODO: More info. from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, MutableSet from numbers import Number from collections import defaultdict, deque +from io import StringIO +from queue import PriorityQueue +import itertools +from graphviz import Digraph -from b_asic.port import SignalSourceProvider, OutputPort, InputPort +from b_asic.port import SignalSourceProvider, OutputPort from b_asic.operation import Operation, AbstractOperation, MutableOutputMap, MutableRegisterMap from b_asic.signal import Signal from b_asic.graph_component import GraphID, GraphIDNumber, GraphComponent, Name, TypeName from b_asic.special_operations import Input, Output, Register -from b_asic.core_operations import Constant class GraphIDGenerator: @@ -41,8 +44,9 @@ class SFG(AbstractOperation): _components_by_id: Dict[GraphID, GraphComponent] _components_by_name: DefaultDict[Name, List[GraphComponent]] - _components_ordered: List[GraphComponent] - _operations_ordered: List[Operation] + _components_dfs_order: List[GraphComponent] + _operations_dfs_order: List[Operation] + _operations_topological_order: List[Operation] _graph_id_generator: GraphIDGenerator _input_operations: List[Input] _output_operations: List[Output] @@ -67,8 +71,9 @@ class SFG(AbstractOperation): self._components_by_id = dict() self._components_by_name = defaultdict(list) - self._components_ordered = [] - self._operations_ordered = [] + self._components_dfs_order = [] + self._operations_dfs_order = [] + self._operations_topological_order = [] self._graph_id_generator = GraphIDGenerator(id_number_offset) self._input_operations = [] self._output_operations = [] @@ -151,9 +156,9 @@ class SFG(AbstractOperation): signal.destination.operation) elif new_signal.destination.operation in output_operations_set: # Add directly connected input to output to ordered list. - self._components_ordered.extend( + self._components_dfs_order.extend( [new_signal.source.operation, new_signal, new_signal.destination.operation]) - self._operations_ordered.extend( + self._operations_dfs_order.extend( [new_signal.source.operation, new_signal.destination.operation]) # Search the graph inwards from each output signal. @@ -170,47 +175,18 @@ class SFG(AbstractOperation): def __str__(self) -> str: """Get a string representation of this SFG.""" - output_string = "" - for component in self._components_ordered: - if isinstance(component, Operation): - for key, value in self._components_by_id.items(): - if value is component: - output_string += "id: " + key + ", name: " - - if component.name != None: - output_string += component.name + ", " - else: - output_string += "-, " + string_io = StringIO() + string_io.write(super().__str__() + "\n") + string_io.write("Internal Operations:\n") + line = "-" * 100 + "\n" + string_io.write(line) - if isinstance(component, Constant): - output_string += "value: " + \ - str(component.value) + ", input: [" - else: - output_string += "input: [" - - counter_input = 0 - for input in component.inputs: - counter_input += 1 - for signal in input.signals: - for key, value in self._components_by_id.items(): - if value is signal: - output_string += key + ", " - - if counter_input > 0: - output_string = output_string[:-2] - output_string += "], output: [" - counter_output = 0 - for output in component.outputs: - counter_output += 1 - for signal in output.signals: - for key, value in self._components_by_id.items(): - if value is signal: - output_string += key + ", " - if counter_output > 0: - output_string = output_string[:-2] - output_string += "]\n" - - return output_string + for operation in self.get_operations_topological_order(): + string_io.write(str(operation) + "\n") + + string_io.write(line) + + return string_io.getvalue() def __call__(self, *src: Optional[SignalSourceProvider], name: Name = "") -> "SFG": """Get a new independent SFG instance that is identical to this SFG except without any of its external connections.""" @@ -248,7 +224,7 @@ class SFG(AbstractOperation): return value def connect_external_signals_to_components(self) -> bool: - """ Connects any external signals to this SFG's internal operations. This SFG becomes unconnected to the SFG + """ Connects any external signals to this SFG's internal operations. This SFG becomes unconnected to the SFG it is a component off, causing it to become invalid afterwards. Returns True if succesful, False otherwise. """ if len(self.inputs) != len(self.input_operations): raise IndexError(f"Number of inputs does not match the number of input_operations in SFG.") @@ -264,7 +240,7 @@ class SFG(AbstractOperation): dest = input_operation.output(0).signals[0].destination dest.clear() port.signals[0].set_destination(dest) - # For each output_signal, connect it to the corresponding operation + # For each output_signal, connect it to the corresponding operation for port, output_operation in zip(self.outputs, self.output_operations): src = output_operation.input(0).signals[0].source src.clear() @@ -284,6 +260,9 @@ class SFG(AbstractOperation): def split(self) -> Iterable[Operation]: return self.operations + def to_sfg(self) -> 'SFG': + return self + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: if output_index < 0 or output_index >= self.output_count: raise IndexError( @@ -325,12 +304,12 @@ class SFG(AbstractOperation): @property def components(self) -> Iterable[GraphComponent]: """Get all components of this graph in depth-first order.""" - return self._components_ordered + return self._components_dfs_order @property def operations(self) -> Iterable[Operation]: """Get all operations of this graph in depth-first order.""" - return self._operations_ordered + return self._operations_dfs_order def get_components_with_type_name(self, type_name: TypeName) -> List[GraphComponent]: """Get a list with all components in this graph with the specified type_name. @@ -384,8 +363,8 @@ class SFG(AbstractOperation): new_op = None if original_op not in self._original_components_to_new: new_op = self._add_component_unconnected_copy(original_op) - self._components_ordered.append(new_op) - self._operations_ordered.append(new_op) + self._components_dfs_order.append(new_op) + self._operations_dfs_order.append(new_op) else: new_op = self._original_components_to_new[original_op] @@ -399,24 +378,20 @@ class SFG(AbstractOperation): if original_signal in self._original_input_signals_to_indices: # New signal already created during first step of constructor. new_signal = self._original_components_to_new[original_signal] - new_signal.set_destination( - new_op.input(original_input_port.index)) - self._components_ordered.extend( - [new_signal, new_signal.source.operation]) - self._operations_ordered.append( - new_signal.source.operation) + new_signal.set_destination(new_op.input(original_input_port.index)) + + self._components_dfs_order.extend([new_signal, new_signal.source.operation]) + self._operations_dfs_order.append(new_signal.source.operation) # Check if the signal has not been added before. elif original_signal not in self._original_components_to_new: if original_signal.source is None: - raise ValueError( - "Dangling signal without source in SFG") + raise ValueError("Dangling signal without source in SFG") - new_signal = self._add_component_unconnected_copy( - original_signal) - new_signal.set_destination( - new_op.input(original_input_port.index)) - self._components_ordered.append(new_signal) + new_signal = self._add_component_unconnected_copy(original_signal) + new_signal.set_destination(new_op.input(original_input_port.index)) + + self._components_dfs_order.append(new_signal) original_connected_op = original_signal.source.operation # Check if connected Operation has been added before. @@ -426,12 +401,11 @@ class SFG(AbstractOperation): original_signal.source.index)) else: # Create new operation, set signal source to it. - new_connected_op = self._add_component_unconnected_copy( - original_connected_op) - new_signal.set_source(new_connected_op.output( - original_signal.source.index)) - self._components_ordered.append(new_connected_op) - self._operations_ordered.append(new_connected_op) + new_connected_op = self._add_component_unconnected_copy(original_connected_op) + new_signal.set_source(new_connected_op.output(original_signal.source.index)) + + self._components_dfs_order.append(new_connected_op) + self._operations_dfs_order.append(new_connected_op) # Add connected operation to queue of operations to visit. op_stack.append(original_connected_op) @@ -443,24 +417,20 @@ class SFG(AbstractOperation): if original_signal in self._original_output_signals_to_indices: # New signal already created during first step of constructor. new_signal = self._original_components_to_new[original_signal] - new_signal.set_source( - new_op.output(original_output_port.index)) - self._components_ordered.extend( - [new_signal, new_signal.destination.operation]) - self._operations_ordered.append( - new_signal.destination.operation) + new_signal.set_source(new_op.output(original_output_port.index)) + + self._components_dfs_order.extend([new_signal, new_signal.destination.operation]) + self._operations_dfs_order.append(new_signal.destination.operation) # Check if signal has not been added before. elif original_signal not in self._original_components_to_new: if original_signal.source is None: - raise ValueError( - "Dangling signal without source in SFG") + raise ValueError("Dangling signal without source in SFG") - new_signal = self._add_component_unconnected_copy( - original_signal) - new_signal.set_source( - new_op.output(original_output_port.index)) - self._components_ordered.append(new_signal) + new_signal = self._add_component_unconnected_copy(original_signal) + new_signal.set_source(new_op.output(original_output_port.index)) + + self._components_dfs_order.append(new_signal) original_connected_op = original_signal.destination.operation # Check if connected operation has been added. @@ -470,12 +440,11 @@ class SFG(AbstractOperation): original_signal.destination.index)) else: # Create new operation, set destination to it. - new_connected_op = self._add_component_unconnected_copy( - original_connected_op) - new_signal.set_destination(new_connected_op.input( - original_signal.destination.index)) - self._components_ordered.append(new_connected_op) - self._operations_ordered.append(new_connected_op) + new_connected_op = self._add_component_unconnected_copy(original_connected_op) + new_signal.set_destination(new_connected_op.input(original_signal.destination.index)) + + self._components_dfs_order.append(new_connected_op) + self._operations_dfs_order.append(new_connected_op) # Add connected operation to the queue of operations to visit. op_stack.append(original_connected_op) @@ -512,7 +481,7 @@ class SFG(AbstractOperation): # The old SFG will be deleted by Python GC return _sfg_copy() - def insert_operation(self, component: Operation, output_comp_id: GraphID): + def insert_operation(self, component: Operation, output_comp_id: GraphID) -> Optional["SFG"]: """Insert an operation in the SFG after a given source operation. The source operation output count must match the input count of the operation as well as the output Then return a new deepcopy of the sfg with the inserted component. @@ -543,6 +512,37 @@ class SFG(AbstractOperation): # Recreate the newly coupled SFG so that all attributes are correct. return sfg_copy() + def remove_operation(self, operation_id: GraphID) -> "SFG": + """Returns a version of the SFG where the operation with the specified GraphID removed. + The operation has to have the same amount of input- and output ports or a ValueError will + be raised. If no operation with the entered operation_id is found then returns None and does nothing.""" + sfg_copy = self() + operation = sfg_copy.find_by_id(operation_id) + if operation is None: + return None + + if operation.input_count != operation.output_count: + raise ValueError("Different number of input and output ports of operation with the specified id") + + for i, outport in enumerate(operation.outputs): + if outport.signal_count > 0: + if operation.input(i).signal_count > 0 and operation.input(i).signals[0].source is not None: + in_sig = operation.input(i).signals[0] + source_port = in_sig.source + source_port.remove_signal(in_sig) + operation.input(i).remove_signal(in_sig) + for out_sig in outport.signals.copy(): + out_sig.set_source(source_port) + else: + for out_sig in outport.signals.copy(): + out_sig.remove_source() + else: + if operation.input(i).signal_count > 0: + in_sig = operation.input(i).signals[0] + operation.input(i).remove_signal(in_sig) + + return sfg_copy() + def _evaluate_source(self, src: OutputPort, results: MutableOutputMap, registers: MutableRegisterMap, prefix: str) -> Number: src_prefix = prefix if src_prefix: @@ -553,16 +553,13 @@ class SFG(AbstractOperation): if key in results: value = results[key] if value is None: - raise RuntimeError( - f"Direct feedback loop detected when evaluating operation.") + raise RuntimeError(f"Direct feedback loop detected when evaluating operation.") return value - results[key] = src.operation.current_output( - src.index, registers, src_prefix) + results[key] = src.operation.current_output(src.index, registers, src_prefix) input_values = [self._evaluate_source( input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs] - value = src.operation.evaluate_output( - src.index, input_values, results, registers, src_prefix) + value = src.operation.evaluate_output(src.index, input_values, results, registers, src_prefix) results[key] = value return value @@ -570,7 +567,7 @@ class SFG(AbstractOperation): """Returns a Precedence list of the SFG where each element in n:th the list consists of elements that are executed in the n:th step. If the precedence list already has been calculated for the current SFG then returns the cached version.""" - if self._precedence_list is not None: + if self._precedence_list: return self._precedence_list # Find all operations with only outputs and no inputs. @@ -584,17 +581,9 @@ class SFG(AbstractOperation): return self._precedence_list - def _traverse_for_precedence_list(self, first_iter_ports): + def _traverse_for_precedence_list(self, first_iter_ports: List[OutputPort]) -> List[List[OutputPort]]: # Find dependencies of output ports and input ports. - outports_per_inport = defaultdict(list) - remaining_inports_per_outport = dict() - for op in self.operations: - op_inputs = op.inputs - for out_i, outport in enumerate(op.outputs): - dependendent_indexes = op.inputs_required_for_output(out_i) - remaining_inports_per_outport[outport] = len(dependendent_indexes) - for in_i in dependendent_indexes: - outports_per_inport[op_inputs[in_i]].append(outport) + remaining_inports_per_operation = {op: op.input_count for op in self.operations} # Traverse output ports for precedence curr_iter_ports = first_iter_ports @@ -611,11 +600,137 @@ class SFG(AbstractOperation): new_inport = signal.destination # Don't traverse over Registers if new_inport is not None and not isinstance(new_inport.operation, Register): - for new_outport in outports_per_inport[new_inport]: - remaining_inports_per_outport[new_outport] -= 1 - if remaining_inports_per_outport[new_outport] == 0: - next_iter_ports.append(new_outport) + new_op = new_inport.operation + remaining_inports_per_operation[new_op] -= 1 + if remaining_inports_per_operation[new_op] == 0: + next_iter_ports.extend(new_op.outputs) curr_iter_ports = next_iter_ports return precedence_list + + def show_precedence_graph(self) -> None: + p_list = self.get_precedence_list() + pg = Digraph() + pg.attr(rankdir = 'LR') + + # Creates nodes for each output port in the precedence list + for i in range(len(p_list)): + ports = p_list[i] + with pg.subgraph(name='cluster_' + str(i)) as sub: + sub.attr(label='N' + str(i + 1)) + for port in ports: + sub.node(port.operation.graph_id + '.' + str(port.index)) + # Creates edges for each output port and creates nodes for each operation and edges for them as well + for i in range(len(p_list)): + ports = p_list[i] + for port in ports: + for signal in port.signals: + pg.edge(port.operation.graph_id + '.' + str(port.index), signal.destination.operation.graph_id) + pg.node(signal.destination.operation.graph_id, shape = 'square') + pg.edge(port.operation.graph_id, port.operation.graph_id + '.' + str(port.index)) + pg.node(port.operation.graph_id, shape = 'square') + + pg.view() + + def print_precedence_graph(self) -> None: + """Prints a representation of the SFG's precedence list to the standard out. + If the precedence list already has been calculated then it uses the cached version, + otherwise it calculates the precedence list and then prints it.""" + precedence_list = self.get_precedence_list() + + line = "-" * 120 + out_str = StringIO() + out_str.write(line) + + printed_ops = set() + + for iter_num, iter in enumerate(precedence_list, start=1): + for outport_num, outport in enumerate(iter, start=1): + if outport not in printed_ops: + # Only print once per operation, even if it has multiple outports + out_str.write("\n") + out_str.write(str(iter_num)) + out_str.write(".") + out_str.write(str(outport_num)) + out_str.write(" \t") + out_str.write(str(outport.operation)) + printed_ops.add(outport) + + out_str.write("\n") + out_str.write(line) + + print(out_str.getvalue()) + + def get_operations_topological_order(self) -> Iterable[Operation]: + """Returns an Iterable of the Operations in the SFG in Topological Order. + Feedback loops makes an absolutely correct Topological order impossible, so an + approximative Topological Order is returned in such cases in this implementation.""" + if self._operations_topological_order: + return self._operations_topological_order + + no_inputs_queue = deque(list(filter(lambda op: op.input_count == 0, self.operations))) + remaining_inports_per_operation = {op: op.input_count for op in self.operations} + + # Maps number of input counts to a queue of seen objects with such a size. + seen_with_inputs_dict = defaultdict(deque) + seen = set() + top_order = [] + + assert len(no_inputs_queue) > 0, "Illegal SFG state, dangling signals in SFG." + + first_op = no_inputs_queue.popleft() + visited = set([first_op]) + p_queue = PriorityQueue() + p_queue.put((-first_op.output_count, first_op)) # Negative priority as max-heap popping is wanted + operations_left = len(self.operations) - 1 + + seen_but_not_visited_count = 0 + + while operations_left > 0: + while not p_queue.empty(): + op = p_queue.get()[1] + + operations_left -= 1 + top_order.append(op) + visited.add(op) + + for neighbor_op in op.subsequent_operations: + if neighbor_op not in visited: + remaining_inports_per_operation[neighbor_op] -= 1 + remaining_inports = remaining_inports_per_operation[neighbor_op] + + if remaining_inports == 0: + p_queue.put((-neighbor_op.output_count, neighbor_op)) + + elif remaining_inports > 0: + if neighbor_op in seen: + seen_with_inputs_dict[remaining_inports + 1].remove(neighbor_op) + else: + seen.add(neighbor_op) + seen_but_not_visited_count += 1 + + seen_with_inputs_dict[remaining_inports].append(neighbor_op) + + # Check if have to fetch Operations from somewhere else since p_queue is empty + if operations_left > 0: + # First check if can fetch from Operations with no input ports + if no_inputs_queue: + new_op = no_inputs_queue.popleft() + p_queue.put((new_op.output_count, new_op)) + + # Else fetch operation with lowest input count that is not zero + elif seen_but_not_visited_count > 0: + for i in itertools.count(start=1): + seen_inputs_queue = seen_with_inputs_dict[i] + if seen_inputs_queue: + new_op = seen_inputs_queue.popleft() + p_queue.put((-new_op.output_count, new_op)) + seen_but_not_visited_count -= 1 + break + else: + raise RuntimeError("Unallowed structure in SFG detected") + + self._operations_topological_order = top_order + + return self._operations_topological_order diff --git a/setup.py b/setup.py index 43d55d40a95212196facb973ebc97a1bdc5e7f42..94285a70c496dc08923a9a36258316e8397670c6 100644 --- a/setup.py +++ b/setup.py @@ -36,9 +36,9 @@ class CMakeBuild(build_ext): if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) - + env = os.environ.copy() - + print(f"=== Configuring {ext.name} ===") print(f"Temp dir: {self.build_temp}") print(f"Output dir: {cmake_output_dir}") @@ -71,7 +71,8 @@ setuptools.setup( install_requires = [ "pybind11>=2.3.0", "numpy", - "install_qt_binding" + "pyside2", + "graphviz" ], packages = ["b_asic"], ext_modules = [CMakeExtension("b_asic")], diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index 5a0ef25b94cec8e3fad9275cccf97882703de330..08d9e8aa2bacd0b1c1a11c17c174179d853e6ed7 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -1,12 +1,12 @@ import pytest -from b_asic import SFG, Input, Output, Constant, Register, ConstantMultiplication +from b_asic import SFG, Input, Output, Constant, Register, ConstantMultiplication, Addition, Butterfly @pytest.fixture def sfg_two_inputs_two_outputs(): """Valid SFG with two inputs and two outputs. - . . + . . in1-------+ +--------->out1 . | | . . v | . @@ -17,9 +17,9 @@ def sfg_two_inputs_two_outputs(): | . ^ . | . | . +------------+ . - . . + . . out1 = in1 + in2 - out2 = in1 + 2 * in2 + out2 = in1 + 2 * in2 """ in1 = Input() in2 = Input() @@ -27,13 +27,14 @@ def sfg_two_inputs_two_outputs(): add2 = add1 + in2 out1 = Output(add1) out2 = Output(add2) - return SFG(inputs = [in1, in2], outputs = [out1, out2]) + return SFG(inputs=[in1, in2], outputs=[out1, out2]) + @pytest.fixture def sfg_two_inputs_two_outputs_independent(): """Valid SFG with two inputs and two outputs, where the first output only depends on the first input and the second output only depends on the second input. - . . + . . in1-------------------->out1 . . . . @@ -44,17 +45,18 @@ def sfg_two_inputs_two_outputs_independent(): . | ^ . . | | . . +------+ . - . . + . . out1 = in1 - out2 = in2 + 3 + out2 = in2 + 3 """ - in1 = Input() - in2 = Input() - c1 = Constant(3) - add1 = in2 + c1 - out1 = Output(in1) - out2 = Output(add1) - return SFG(inputs = [in1, in2], outputs = [out1, out2]) + in1 = Input("IN1") + in2 = Input("IN2") + c1 = Constant(3, "C1") + add1 = Addition(in2, c1, "ADD1") + out1 = Output(in1, "OUT1") + out2 = Output(add1, "OUT2") + return SFG(inputs=[in1, in2], outputs=[out1, out2]) + @pytest.fixture def sfg_nested(): @@ -65,7 +67,7 @@ def sfg_nested(): mac_in2 = Input() mac_in3 = Input() mac_out1 = Output(mac_in1 + mac_in2 * mac_in3) - MAC = SFG(inputs = [mac_in1, mac_in2, mac_in3], outputs = [mac_out1]) + MAC = SFG(inputs=[mac_in1, mac_in2, mac_in3], outputs=[mac_out1]) in1 = Input() in2 = Input() @@ -73,7 +75,8 @@ def sfg_nested(): mac2 = MAC(in1, in2, mac1) mac3 = MAC(in1, mac1, mac2) out1 = Output(mac3) - return SFG(inputs = [in1, in2], outputs = [out1]) + return SFG(inputs=[in1, in2], outputs=[out1]) + @pytest.fixture def sfg_delay(): @@ -83,7 +86,8 @@ def sfg_delay(): in1 = Input() reg1 = Register(in1) out1 = Output(reg1) - return SFG(inputs = [in1], outputs = [out1]) + return SFG(inputs=[in1], outputs=[out1]) + @pytest.fixture def sfg_accumulator(): @@ -95,7 +99,8 @@ def sfg_accumulator(): reg = Register() reg.input(0).connect((reg + data_in) * (1 - reset)) data_out = Output(reg) - return SFG(inputs = [data_in, reset], outputs = [data_out]) + return SFG(inputs=[data_in, reset], outputs=[data_out]) + @pytest.fixture def simple_filter(): @@ -105,11 +110,71 @@ def simple_filter(): | | in1>------add1>------reg>------+------out1> """ - in1 = Input() - reg = Register() - constmul1 = ConstantMultiplication(0.5) - add1 = in1 + constmul1 - reg.input(0).connect(add1) - constmul1.input(0).connect(reg) - out1 = Output(reg) - return SFG(inputs=[in1], outputs=[out1]) + in1 = Input("IN1") + constmul1 = ConstantMultiplication(0.5, name="CMUL1") + add1 = Addition(in1, constmul1, "ADD1") + add1.input(1).signals[0].name = "S2" + reg = Register(add1, name="REG1") + constmul1.input(0).connect(reg, "S1") + out1 = Output(reg, "OUT1") + return SFG(inputs=[in1], outputs=[out1], name="simple_filter") + + +@pytest.fixture +def precedence_sfg_registers(): + """A sfg with registers and interesting layout for precednce list generation. + + IN1>--->C0>--->ADD1>--->Q1>---+--->A0>--->ADD4>--->OUT1 + ^ | ^ + | T1 | + | | | + ADD2<---<B1<---+--->A1>--->ADD3 + ^ | ^ + | T2 | + | | | + +-----<B2<---+--->A2>-----+ + """ + in1 = Input("IN1") + c0 = ConstantMultiplication(5, in1, "C0") + add1 = Addition(c0, None, "ADD1") + # Not sure what operation "Q" is supposed to be in the example + Q1 = ConstantMultiplication(1, add1, "Q1") + T1 = Register(Q1, 0, "T1") + T2 = Register(T1, 0, "T2") + b2 = ConstantMultiplication(2, T2, "B2") + b1 = ConstantMultiplication(3, T1, "B1") + add2 = Addition(b1, b2, "ADD2") + add1.input(1).connect(add2) + a1 = ConstantMultiplication(4, T1, "A1") + a2 = ConstantMultiplication(6, T2, "A2") + add3 = Addition(a1, a2, "ADD3") + a0 = ConstantMultiplication(7, Q1, "A0") + add4 = Addition(a0, add3, "ADD4") + out1 = Output(add4, "OUT1") + + return SFG(inputs=[in1], outputs=[out1], name="SFG") + + +@pytest.fixture +def precedence_sfg_registers_and_constants(): + in1 = Input("IN1") + c0 = ConstantMultiplication(5, in1, "C0") + add1 = Addition(c0, None, "ADD1") + # Not sure what operation "Q" is supposed to be in the example + Q1 = ConstantMultiplication(1, add1, "Q1") + T1 = Register(Q1, 0, "T1") + const1 = Constant(10, "CONST1") # Replace T2 register with a constant + b2 = ConstantMultiplication(2, const1, "B2") + b1 = ConstantMultiplication(3, T1, "B1") + add2 = Addition(b1, b2, "ADD2") + add1.input(1).connect(add2) + a1 = ConstantMultiplication(4, T1, "A1") + a2 = ConstantMultiplication(10, const1, "A2") + add3 = Addition(a1, a2, "ADD3") + a0 = ConstantMultiplication(7, Q1, "A0") + # Replace ADD4 with a butterfly to test multiple output ports + bfly1 = Butterfly(a0, add3, "BFLY1") + out1 = Output(bfly1.output(0), "OUT1") + out2 = Output(bfly1.output(1), "OUT2") + + return SFG(inputs=[in1], outputs=[out1], name="SFG") diff --git a/test/test_abstract_operation.py b/test/test_abstract_operation.py index 5423ecdf08c420df5dccc6393c3ad6637961172b..9163fce2a955c7fbc68d5d24de86896d251934da 100644 --- a/test/test_abstract_operation.py +++ b/test/test_abstract_operation.py @@ -89,4 +89,3 @@ def test_division_overload(): assert isinstance(div3, Division) assert div3.input(0).signals[0].source.operation.value == 5 assert div3.input(1).signals == div2.output(0).signals - diff --git a/test/test_core_operations.py b/test/test_core_operations.py index 2eb341da88a851ac0fd26939da64377ea27963a1..6a0493c60965579bd843e0b514bd7f9b9a0e4707 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -6,7 +6,6 @@ from b_asic import \ Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \ SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly - class TestConstant: def test_constant_positive(self): test_operation = Constant(3) diff --git a/test/test_operation.py b/test/test_operation.py index b76ba16d11425c0ce868e4fa0b4c88d9f862e23f..77e9ba3cbd0eaa75886b5a7e5d11f00f6cfeb479 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -1,6 +1,6 @@ import pytest -from b_asic import Constant, Addition +from b_asic import Constant, Addition, MAD, Butterfly, SquareRoot class TestTraverse: def test_traverse_single_tree(self, operation): @@ -22,4 +22,32 @@ class TestTraverse: assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4 def test_traverse_loop(self, operation_graph_with_cycle): - assert len(list(operation_graph_with_cycle.traverse())) == 8 \ No newline at end of file + assert len(list(operation_graph_with_cycle.traverse())) == 8 + +class TestToSfg: + def test_convert_mad_to_sfg(self): + mad1 = MAD() + mad1_sfg = mad1.to_sfg() + + assert mad1.evaluate(1,1,1) == mad1_sfg.evaluate(1,1,1) + assert len(mad1_sfg.operations) == 6 + + def test_butterfly_to_sfg(self): + but1 = Butterfly() + but1_sfg = but1.to_sfg() + + assert but1.evaluate(1,1)[0] == but1_sfg.evaluate(1,1)[0] + assert but1.evaluate(1,1)[1] == but1_sfg.evaluate(1,1)[1] + assert len(but1_sfg.operations) == 8 + + def test_add_to_sfg(self): + add1 = Addition() + add1_sfg = add1.to_sfg() + + assert len(add1_sfg.operations) == 4 + + def test_sqrt_to_sfg(self): + sqrt1 = SquareRoot() + sqrt1_sfg = sqrt1.to_sfg() + + assert len(sqrt1_sfg.operations) == 3 diff --git a/test/test_sfg.py b/test/test_sfg.py index 5f86739517b0d4c7bc9b242de24c3777222b51d2..a27b404e8d2ac0eaf6b4146ed497e2f06b5973cf 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1,4 +1,7 @@ import pytest +import io +import sys + from b_asic import SFG, Signal, Input, Output, Constant, ConstantMultiplication, Addition, Multiplication, Register, \ Butterfly, Subtraction, SquareRoot @@ -54,13 +57,17 @@ class TestPrintSfg: inp2 = Input("INP2") add1 = Addition(inp1, inp2, "ADD1") out1 = Output(add1, "OUT1") - sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="sf1") + sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="SFG1") assert sfg.__str__() == \ - "id: add1, name: ADD1, input: [s1, s2], output: [s3]\n" + \ - "id: in1, name: INP1, input: [], output: [s1]\n" + \ - "id: in2, name: INP2, input: [], output: [s2]\n" + \ - "id: out1, name: OUT1, input: [s3], output: []\n" + "id: no_id, \tname: SFG1, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + \ + "Internal Operations:\n" + \ + "----------------------------------------------------------------------------------------------------\n" + \ + str(sfg.find_by_name("INP1")[0]) + "\n" + \ + str(sfg.find_by_name("INP2")[0]) + "\n" + \ + str(sfg.find_by_name("ADD1")[0]) + "\n" + \ + str(sfg.find_by_name("OUT1")[0]) + "\n" + \ + "----------------------------------------------------------------------------------------------------\n" def test_add_mul(self): inp1 = Input("INP1") @@ -72,12 +79,16 @@ class TestPrintSfg: sfg = SFG(inputs=[inp1, inp2, inp3], outputs=[out1], name="mac_sfg") assert sfg.__str__() == \ - "id: add1, name: ADD1, input: [s1, s2], output: [s5]\n" + \ - "id: in1, name: INP1, input: [], output: [s1]\n" + \ - "id: in2, name: INP2, input: [], output: [s2]\n" + \ - "id: mul1, name: MUL1, input: [s5, s3], output: [s4]\n" + \ - "id: in3, name: INP3, input: [], output: [s3]\n" + \ - "id: out1, name: OUT1, input: [s4], output: []\n" + "id: no_id, \tname: mac_sfg, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + \ + "Internal Operations:\n" + \ + "----------------------------------------------------------------------------------------------------\n" + \ + str(sfg.find_by_name("INP1")[0]) + "\n" + \ + str(sfg.find_by_name("INP2")[0]) + "\n" + \ + str(sfg.find_by_name("ADD1")[0]) + "\n" + \ + str(sfg.find_by_name("INP3")[0]) + "\n" + \ + str(sfg.find_by_name("MUL1")[0]) + "\n" + \ + str(sfg.find_by_name("OUT1")[0]) + "\n" + \ + "----------------------------------------------------------------------------------------------------\n" def test_constant(self): inp1 = Input("INP1") @@ -88,18 +99,27 @@ class TestPrintSfg: sfg = SFG(inputs=[inp1], outputs=[out1], name="sfg") assert sfg.__str__() == \ - "id: add1, name: ADD1, input: [s3, s1], output: [s2]\n" + \ - "id: c1, name: CONST, value: 3, input: [], output: [s3]\n" + \ - "id: in1, name: INP1, input: [], output: [s1]\n" + \ - "id: out1, name: OUT1, input: [s2], output: []\n" + "id: no_id, \tname: sfg, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + \ + "Internal Operations:\n" + \ + "----------------------------------------------------------------------------------------------------\n" + \ + str(sfg.find_by_name("CONST")[0]) + "\n" + \ + str(sfg.find_by_name("INP1")[0]) + "\n" + \ + str(sfg.find_by_name("ADD1")[0]) + "\n" + \ + str(sfg.find_by_name("OUT1")[0]) + "\n" + \ + "----------------------------------------------------------------------------------------------------\n" def test_simple_filter(self, simple_filter): + assert simple_filter.__str__() == \ - 'id: add1, name: , input: [s1, s3], output: [s4]\n' + \ - 'id: in1, name: , input: [], output: [s1]\n' + \ - 'id: cmul1, name: , input: [s5], output: [s3]\n' + \ - 'id: reg1, name: , input: [s4], output: [s5, s2]\n' + \ - 'id: out1, name: , input: [s2], output: []\n' + "id: no_id, \tname: simple_filter, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + \ + "Internal Operations:\n" + \ + "----------------------------------------------------------------------------------------------------\n" + \ + str(simple_filter.find_by_name("IN1")[0]) + "\n" + \ + str(simple_filter.find_by_name("ADD1")[0]) + "\n" + \ + str(simple_filter.find_by_name("REG1")[0]) + "\n" + \ + str(simple_filter.find_by_name("CMUL1")[0]) + "\n" + \ + str(simple_filter.find_by_name("OUT1")[0]) + "\n" + \ + "----------------------------------------------------------------------------------------------------\n" class TestDeepCopy: @@ -267,7 +287,7 @@ class TestInsertComponent: _sfg = sfg.insert_operation(sqrt, sfg.find_by_name("constant4")[0].graph_id) assert _sfg.evaluate() != sfg.evaluate() - + assert any([isinstance(comp, SquareRoot) for comp in _sfg.operations]) assert not any([isinstance(comp, SquareRoot) for comp in sfg.operations]) @@ -275,7 +295,8 @@ class TestInsertComponent: assert isinstance(_sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot) assert sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is sfg.find_by_id("add3") - assert _sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is not _sfg.find_by_id("add3") + assert _sfg.find_by_name("constant4")[0].output( + 0).signals[0].destination.operation is not _sfg.find_by_id("add3") assert _sfg.find_by_id("sqrt1").output(0).signals[0].destination.operation is _sfg.find_by_id("add3") def test_insert_invalid_component_in_sfg(self, large_operation_tree): @@ -304,22 +325,26 @@ class TestInsertComponent: assert len(_sfg.find_by_name("n_bfly")) == 1 # Correctly connected old output -> new input - assert _sfg.find_by_name("bfly3")[0].output(0).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] - assert _sfg.find_by_name("bfly3")[0].output(1).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] + assert _sfg.find_by_name("bfly3")[0].output( + 0).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] + assert _sfg.find_by_name("bfly3")[0].output( + 1).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] # Correctly connected new input -> old output assert _sfg.find_by_name("n_bfly")[0].input(0).signals[0].source.operation is _sfg.find_by_name("bfly3")[0] assert _sfg.find_by_name("n_bfly")[0].input(1).signals[0].source.operation is _sfg.find_by_name("bfly3")[0] # Correctly connected new output -> next input - assert _sfg.find_by_name("n_bfly")[0].output(0).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] - assert _sfg.find_by_name("n_bfly")[0].output(1).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] + assert _sfg.find_by_name("n_bfly")[0].output( + 0).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] + assert _sfg.find_by_name("n_bfly")[0].output( + 1).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] # Correctly connected next input -> new output assert _sfg.find_by_name("bfly2")[0].input(0).signals[0].source.operation is _sfg.find_by_name("n_bfly")[0] assert _sfg.find_by_name("bfly2")[0].input(1).signals[0].source.operation is _sfg.find_by_name("n_bfly")[0] - + class TestFindComponentsWithTypeName: def test_mac_components(self): inp1 = Input("INP1") @@ -358,28 +383,9 @@ class TestFindComponentsWithTypeName: class TestGetPrecedenceList: - def test_inputs_registers(self): - in1 = Input("IN1") - c0 = ConstantMultiplication(5, in1, "C0") - add1 = Addition(c0, None, "ADD1") - # Not sure what operation "Q" is supposed to be in the example - Q1 = ConstantMultiplication(1, add1, "Q1") - T1 = Register(Q1, 0, "T1") - T2 = Register(T1, 0, "T2") - b2 = ConstantMultiplication(2, T2, "B2") - b1 = ConstantMultiplication(3, T1, "B1") - add2 = Addition(b1, b2, "ADD2") - add1.input(1).connect(add2) - a1 = ConstantMultiplication(4, T1, "A1") - a2 = ConstantMultiplication(6, T2, "A2") - add3 = Addition(a1, a2, "ADD3") - a0 = ConstantMultiplication(7, Q1, "A0") - add4 = Addition(a0, add3, "ADD4") - out1 = Output(add4, "OUT1") - - sfg = SFG(inputs=[in1], outputs=[out1], name="SFG") + def test_inputs_registers(self, precedence_sfg_registers): - precedence_list = sfg.get_precedence_list() + precedence_list = precedence_sfg_registers.get_precedence_list() assert len(precedence_list) == 7 @@ -404,30 +410,9 @@ class TestGetPrecedenceList: assert set([port.operation.key(port.index, port.operation.name) for port in precedence_list[6]]) == {"ADD4"} - def test_inputs_constants_registers_multiple_outputs(self): - in1 = Input("IN1") - c0 = ConstantMultiplication(5, in1, "C0") - add1 = Addition(c0, None, "ADD1") - # Not sure what operation "Q" is supposed to be in the example - Q1 = ConstantMultiplication(1, add1, "Q1") - T1 = Register(Q1, 0, "T1") - const1 = Constant(10, "CONST1") # Replace T2 register with a constant - b2 = ConstantMultiplication(2, const1, "B2") - b1 = ConstantMultiplication(3, T1, "B1") - add2 = Addition(b1, b2, "ADD2") - add1.input(1).connect(add2) - a1 = ConstantMultiplication(4, T1, "A1") - a2 = ConstantMultiplication(10, const1, "A2") - add3 = Addition(a1, a2, "ADD3") - a0 = ConstantMultiplication(7, Q1, "A0") - # Replace ADD4 with a butterfly to test multiple output ports - bfly1 = Butterfly(a0, add3, "BFLY1") - out1 = Output(bfly1.output(0), "OUT1") - out2 = Output(bfly1.output(1), "OUT2") - - sfg = SFG(inputs=[in1], outputs=[out1], name="SFG") + def test_inputs_constants_registers_multiple_outputs(self, precedence_sfg_registers_and_constants): - precedence_list = sfg.get_precedence_list() + precedence_list = precedence_sfg_registers_and_constants.get_precedence_list() assert len(precedence_list) == 7 @@ -502,10 +487,48 @@ class TestGetPrecedenceList: for port in precedence_list[0]]) == {"IN1", "IN2"} assert set([port.operation.key(port.index, port.operation.name) - for port in precedence_list[1]]) == {"NESTED_SFG.0", "CMUL1"} + for port in precedence_list[1]]) == {"CMUL1"} assert set([port.operation.key(port.index, port.operation.name) - for port in precedence_list[2]]) == {"NESTED_SFG.1"} + for port in precedence_list[2]]) == {"NESTED_SFG.0", "NESTED_SFG.1"} + + +class TestPrintPrecedence: + def test_registers(self, precedence_sfg_registers): + sfg = precedence_sfg_registers + + captured_output = io.StringIO() + sys.stdout = captured_output + + sfg.print_precedence_graph() + + sys.stdout = sys.__stdout__ + + captured_output = captured_output.getvalue() + + assert captured_output == \ + "-" * 120 + "\n" + \ + "1.1 \t" + str(sfg.find_by_name("IN1")[0]) + "\n" + \ + "1.2 \t" + str(sfg.find_by_name("T1")[0]) + "\n" + \ + "1.3 \t" + str(sfg.find_by_name("T2")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "2.1 \t" + str(sfg.find_by_name("C0")[0]) + "\n" + \ + "2.2 \t" + str(sfg.find_by_name("A1")[0]) + "\n" + \ + "2.3 \t" + str(sfg.find_by_name("B1")[0]) + "\n" + \ + "2.4 \t" + str(sfg.find_by_name("A2")[0]) + "\n" + \ + "2.5 \t" + str(sfg.find_by_name("B2")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "3.1 \t" + str(sfg.find_by_name("ADD3")[0]) + "\n" + \ + "3.2 \t" + str(sfg.find_by_name("ADD2")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "4.1 \t" + str(sfg.find_by_name("ADD1")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "5.1 \t" + str(sfg.find_by_name("Q1")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "6.1 \t" + str(sfg.find_by_name("A0")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "7.1 \t" + str(sfg.find_by_name("ADD4")[0]) + "\n" + \ + "-" * 120 + "\n" class TestDepends: @@ -668,7 +691,88 @@ class TestConnectExternalSignalsToComponentsMultipleComp: out1.input(0).connect(sub1, "S7") test_sfg = SFG(inputs=[inp1, inp2, inp3, inp4], outputs=[out1]) + assert test_sfg.evaluate(1, 2, 3, 4) == 16 sfg1.connect_external_signals_to_components() assert test_sfg.evaluate(1, 2, 3, 4) == 16 assert not test_sfg.connect_external_signals_to_components() + +class TestTopologicalOrderOperations: + def test_feedback_sfg(self, simple_filter): + topological_order = simple_filter.get_operations_topological_order() + + assert [comp.name for comp in topological_order] == ["IN1", "ADD1", "REG1", "CMUL1", "OUT1"] + + def test_multiple_independent_inputs(self, sfg_two_inputs_two_outputs_independent): + topological_order = sfg_two_inputs_two_outputs_independent.get_operations_topological_order() + + assert [comp.name for comp in topological_order] == ["IN1", "OUT1", "IN2", "C1", "ADD1", "OUT2"] + + +class TestRemove: + def test_remove_single_input_outputs(self, simple_filter): + new_sfg = simple_filter.remove_operation("cmul1") + + assert set(op.name for op in simple_filter.find_by_name("REG1")[0].subsequent_operations) == {"CMUL1", "OUT1"} + assert set(op.name for op in new_sfg.find_by_name("REG1")[0].subsequent_operations) == {"ADD1", "OUT1"} + + assert set(op.name for op in simple_filter.find_by_name("ADD1")[0].preceding_operations) == {"CMUL1", "IN1"} + assert set(op.name for op in new_sfg.find_by_name("ADD1")[0].preceding_operations) == {"REG1", "IN1"} + + assert "S1" in set([sig.name for sig in simple_filter.find_by_name("REG1")[0].output(0).signals]) + assert "S2" in set([sig.name for sig in new_sfg.find_by_name("REG1")[0].output(0).signals]) + + def test_remove_multiple_inputs_outputs(self, butterfly_operation_tree): + out1 = Output(butterfly_operation_tree.output(0), "OUT1") + out2 = Output(butterfly_operation_tree.output(1), "OUT2") + + sfg = SFG(outputs=[out1, out2]) + + new_sfg = sfg.remove_operation(sfg.find_by_name("bfly2")[0].graph_id) + + assert sfg.find_by_name("bfly3")[0].output(0).signal_count == 1 + assert new_sfg.find_by_name("bfly3")[0].output(0).signal_count == 1 + + sfg_dest_0 = sfg.find_by_name("bfly3")[0].output(0).signals[0].destination + new_sfg_dest_0 = new_sfg.find_by_name("bfly3")[0].output(0).signals[0].destination + + assert sfg_dest_0.index == 0 + assert new_sfg_dest_0.index == 0 + assert sfg_dest_0.operation.name == "bfly2" + assert new_sfg_dest_0.operation.name == "bfly1" + + assert sfg.find_by_name("bfly3")[0].output(1).signal_count == 1 + assert new_sfg.find_by_name("bfly3")[0].output(1).signal_count == 1 + + sfg_dest_1 = sfg.find_by_name("bfly3")[0].output(1).signals[0].destination + new_sfg_dest_1 = new_sfg.find_by_name("bfly3")[0].output(1).signals[0].destination + + assert sfg_dest_1.index == 1 + assert new_sfg_dest_1.index == 1 + assert sfg_dest_1.operation.name == "bfly2" + assert new_sfg_dest_1.operation.name == "bfly1" + + assert sfg.find_by_name("bfly1")[0].input(0).signal_count == 1 + assert new_sfg.find_by_name("bfly1")[0].input(0).signal_count == 1 + + sfg_source_0 = sfg.find_by_name("bfly1")[0].input(0).signals[0].source + new_sfg_source_0 = new_sfg.find_by_name("bfly1")[0].input(0).signals[0].source + + assert sfg_source_0.index == 0 + assert new_sfg_source_0.index == 0 + assert sfg_source_0.operation.name == "bfly2" + assert new_sfg_source_0.operation.name == "bfly3" + + sfg_source_1 = sfg.find_by_name("bfly1")[0].input(1).signals[0].source + new_sfg_source_1 = new_sfg.find_by_name("bfly1")[0].input(1).signals[0].source + + assert sfg_source_1.index == 1 + assert new_sfg_source_1.index == 1 + assert sfg_source_1.operation.name == "bfly2" + assert new_sfg_source_1.operation.name == "bfly3" + + assert "bfly2" not in set(op.name for op in new_sfg.operations) + + def remove_different_number_inputs_outputs(self, simple_filter): + with pytest.raises(ValueError): + simple_filter.remove_operation("add1")