From 60587da50bf169fa33d052e042ef956c44416a1c Mon Sep 17 00:00:00 2001 From: Simon Bjurek <simbj106@student.liu.se> Date: Sun, 16 Mar 2025 14:20:10 +0000 Subject: [PATCH] Added ruff bugbear and simplify rules and reformatted code accordingly --- b_asic/GUI/main_window.py | 23 ++++---- b_asic/architecture.py | 6 +- b_asic/codegen/vhdl/entity.py | 2 +- b_asic/operation.py | 4 +- b_asic/port.py | 2 +- b_asic/process.py | 4 +- b_asic/quantization.py | 5 +- b_asic/research/interleaver.py | 8 +-- b_asic/resources.py | 41 +++++-------- b_asic/schedule.py | 10 ++-- b_asic/scheduler.py | 76 +++++++++++++------------ b_asic/scheduler_gui/compile.py | 8 +-- b_asic/scheduler_gui/main_window.py | 32 ++++------- b_asic/scheduler_gui/scheduler_event.py | 12 ++-- b_asic/scheduler_gui/scheduler_item.py | 9 +-- b_asic/sfg_generators.py | 4 +- b_asic/signal_flow_graph.py | 52 +++++++++++------ b_asic/utils.py | 2 +- pyproject.toml | 6 +- test/unit/test_list_schedulers.py | 2 +- test/unit/test_sfg.py | 14 ++--- test/unit/test_signal.py | 2 +- 22 files changed, 159 insertions(+), 165 deletions(-) diff --git a/b_asic/GUI/main_window.py b/b_asic/GUI/main_window.py index 15b5f98c..d26b14ad 100644 --- a/b_asic/GUI/main_window.py +++ b/b_asic/GUI/main_window.py @@ -382,7 +382,7 @@ class SFGMainWindow(QMainWindow): self.update() def _create_recent_file_actions_and_menus(self): - for i in range(self._max_recent_files): + for _ in range(self._max_recent_files): recent_file_action = QAction(self._ui.recent_sfg) recent_file_action.setVisible(False) recent_file_action.triggered.connect( @@ -510,24 +510,24 @@ class SFGMainWindow(QMainWindow): and hasattr(source_operation2, "value") and hasattr(dest_operation, "value") and hasattr(dest_operation2, "value") - ): - if not ( + and not ( source_operation.value == source_operation2.value and dest_operation.value == dest_operation2.value - ): - return False + ) + ): + return False if ( hasattr(source_operation, "name") and hasattr(source_operation2, "name") and hasattr(dest_operation, "name") and hasattr(dest_operation2, "name") - ): - if not ( + and not ( source_operation.name == source_operation2.name and dest_operation.name == dest_operation2.name - ): - return False + ) + ): + return False try: signal_source_index = [ @@ -744,9 +744,8 @@ class SFGMainWindow(QMainWindow): operation_label.moveBy(10, -20) attr_button.add_label(operation_label) - if isinstance(is_flipped, bool): - if is_flipped: - attr_button._flip() + if isinstance(is_flipped, bool) and is_flipped: + attr_button._flip() self._drag_buttons[op] = attr_button self._drag_operation_scenes[attr_button] = attr_button_scene diff --git a/b_asic/architecture.py b/b_asic/architecture.py index 3d18f17b..419181d5 100644 --- a/b_asic/architecture.py +++ b/b_asic/architecture.py @@ -558,6 +558,8 @@ class Memory(Resource): if self._memory_type == "RAM": plural_s = 's' if len(self._assignment) >= 2 else '' return f": (RAM, {len(self._assignment)} cell{plural_s})" + else: + pass return "" def assign(self, heuristic: str = "left_edge") -> None: @@ -1036,7 +1038,7 @@ of :class:`~b_asic.architecture.ProcessingElement` fontname='Times New Roman', ) else: - for i, mem in enumerate(self._memories): + for mem in self._memories: dg.node( mem.entity_name, mem._struct_def(), @@ -1044,7 +1046,7 @@ of :class:`~b_asic.architecture.ProcessingElement` fillcolor=memory_color, fontname='Times New Roman', ) - for i, pe in enumerate(self._processing_elements): + for pe in self._processing_elements: dg.node( pe.entity_name, pe._struct_def(), style='filled', fillcolor=pe_color ) diff --git a/b_asic/codegen/vhdl/entity.py b/b_asic/codegen/vhdl/entity.py index c52f8584..40a7caf7 100644 --- a/b_asic/codegen/vhdl/entity.py +++ b/b_asic/codegen/vhdl/entity.py @@ -57,7 +57,7 @@ def memory_based_storage( read_ports: set[Port] = { read_port for mv in collection for read_port in mv.read_ports } # type: ignore - for idx, read_port in enumerate(read_ports): + for read_port in read_ports: port_name = read_port if isinstance(read_port, int) else read_port.name port_name = 'p_' + str(port_name) + '_in' f.write(f'{2*VHDL_TAB}{port_name} : in std_logic_vector(WL-1 downto 0);\n') diff --git a/b_asic/operation.py b/b_asic/operation.py index bc6bb748..98c46308 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -1046,9 +1046,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): @property def is_linear(self) -> bool: # doc-string inherited - if self.is_constant: - return True - return False + return self.is_constant @property def is_constant(self) -> bool: diff --git a/b_asic/port.py b/b_asic/port.py index 65df10d3..e12ae556 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -359,7 +359,7 @@ class InputPort(AbstractPort): tmp_signal = self.signals[0] tmp_signal.remove_destination() current = self - for i in range(number): + for _ in range(number): d = Delay() current.connect(d) current = d.input(0) diff --git a/b_asic/process.py b/b_asic/process.py index 0a5d7fa7..d165a13e 100644 --- a/b_asic/process.py +++ b/b_asic/process.py @@ -319,7 +319,7 @@ class MemoryVariable(MemoryProcess): return self._write_port def __repr__(self) -> str: - reads = {k: v for k, v in zip(self._read_ports, self._life_times)} + reads = {k: v for k, v in zip(self._read_ports, self._life_times, strict=True)} return ( f"MemoryVariable({self.start_time}, {self.write_port}," f" {reads!r}, {self.name!r})" @@ -413,7 +413,7 @@ class PlainMemoryVariable(MemoryProcess): return self._write_port def __repr__(self) -> str: - reads = {k: v for k, v in zip(self._read_ports, self._life_times)} + reads = {k: v for k, v in zip(self._read_ports, self._life_times, strict=True)} return ( f"PlainMemoryVariable({self.start_time}, {self.write_port}," f" {reads!r}, {self.name!r})" diff --git a/b_asic/quantization.py b/b_asic/quantization.py index 26cc2c70..f53b1e94 100644 --- a/b_asic/quantization.py +++ b/b_asic/quantization.py @@ -124,10 +124,7 @@ def quantize( elif quantization is Quantization.ROUNDING: v = math.floor(v + 0.5) elif quantization is Quantization.MAGNITUDE_TRUNCATION: - if v >= 0: - v = math.floor(v) - else: - v = math.ceil(v) + v = math.floor(v) if v >= 0 else math.ceil(v) elif quantization is Quantization.JAMMING: v = math.floor(v) | 1 elif quantization is Quantization.UNBIASED_ROUNDING: diff --git a/b_asic/research/interleaver.py b/b_asic/research/interleaver.py index 86b81491..296accea 100644 --- a/b_asic/research/interleaver.py +++ b/b_asic/research/interleaver.py @@ -20,10 +20,10 @@ def _insert_delays( maxdiff = min(outputorder[i][0] - inputorder[i][0] for i in range(size)) outputorder = [(o[0] - maxdiff + min_lifetime, o[1]) for o in outputorder] maxdelay = max(outputorder[i][0] - inputorder[i][0] for i in range(size)) - if cyclic: - if maxdelay >= time: - inputorder = inputorder + [(i[0] + time, i[1]) for i in inputorder] - outputorder = outputorder + [(o[0] + time, o[1]) for o in outputorder] + + if cyclic and maxdelay >= time: + inputorder = inputorder + [(i[0] + time, i[1]) for i in inputorder] + outputorder = outputorder + [(o[0] + time, o[1]) for o in outputorder] return inputorder, outputorder diff --git a/b_asic/resources.py b/b_asic/resources.py index 1af2b9b0..53a5ef25 100644 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -1340,29 +1340,18 @@ class ProcessCollection: raise ValueError(f'{mv!r} is not part of {self!r}.') # Make sure that concurrent reads/writes do not surpass the port setting - for mv in self: - - def filter_write(p): - return p.start_time == mv.start_time - - def filter_read(p): - return ( - (p.start_time + p.execution_time) % self._schedule_time - == mv.start_time + mv.execution_time % self._schedule_time - ) - - needed_write_ports = len(list(filter(filter_write, self))) - needed_read_ports = len(list(filter(filter_read, self))) - if needed_write_ports > write_ports + 1: - raise ValueError( - f'More than {write_ports} write ports needed ({needed_write_ports})' - ' to generate HDL for this ProcessCollection' - ) - if needed_read_ports > read_ports + 1: - raise ValueError( - f'More than {read_ports} read ports needed ({needed_read_ports}) to' - ' generate HDL for this ProcessCollection' - ) + needed_write_ports = self.read_ports_bound() + needed_read_ports = self.write_ports_bound() + if needed_write_ports > write_ports + 1: + raise ValueError( + f'More than {write_ports} write ports needed ({needed_write_ports})' + ' to generate HDL for this ProcessCollection' + ) + if needed_read_ports > read_ports + 1: + raise ValueError( + f'More than {read_ports} read ports needed ({needed_read_ports}) to' + ' generate HDL for this ProcessCollection' + ) # Sanitize the address logic pipeline settings if adr_mux_size is not None and adr_pipe_depth is not None: @@ -1648,11 +1637,11 @@ class ProcessCollection: axes : list of three :class:`matplotlib.axes.Axes` Three Axes to plot in. """ - axes[0].bar(*zip(*self.read_port_accesses().items())) + axes[0].bar(*zip(*self.read_port_accesses().items(), strict=True)) axes[0].set_title("Read port accesses") - axes[1].bar(*zip(*self.write_port_accesses().items())) + axes[1].bar(*zip(*self.write_port_accesses().items(), strict=True)) axes[1].set_title("Write port accesses") - axes[2].bar(*zip(*self.total_port_accesses().items())) + axes[2].bar(*zip(*self.total_port_accesses().items(), strict=True)) axes[2].set_title("Total port accesses") for ax in axes: ax.xaxis.set_major_locator(MaxNLocator(integer=True, min_n_ticks=1)) diff --git a/b_asic/schedule.py b/b_asic/schedule.py index e4d6734d..8525a937 100644 --- a/b_asic/schedule.py +++ b/b_asic/schedule.py @@ -462,7 +462,7 @@ class Schedule: # if updating the scheduling time -> update laps due to operations # reading and writing in different iterations (across the edge) if self._schedule_time is not None: - for signal_id in self._laps.keys(): + for signal_id in self._laps: port = self._sfg.find_by_id(signal_id).destination source_port = port.signals[0].source @@ -701,7 +701,7 @@ class Schedule: else: offset += 1 - for gid, y_location in self._y_locations.items(): + for gid in self._y_locations: self._y_locations[gid] = remapping[self._y_locations[gid]] def get_y_location(self, graph_id: GraphID) -> int: @@ -992,7 +992,7 @@ class Schedule: destination_laps.append((port.operation.graph_id, port.index, lap)) for op, port, lap in destination_laps: - for delays in range(lap): + for _ in range(lap): new_sfg = new_sfg.insert_operation_before(op, Delay(), port) return new_sfg() @@ -1234,7 +1234,7 @@ class Schedule: latency_coordinates, execution_time_coordinates, ) = operation.get_plot_coordinates() - _x, _y = zip(*latency_coordinates) + _x, _y = zip(*latency_coordinates, strict=True) x = np.array(_x) y = np.array(_y) xvalues = x + op_start_time @@ -1258,7 +1258,7 @@ class Schedule: size=10 - (0.05 * len(self._start_times)), ) if execution_time_coordinates: - _x, _y = zip(*execution_time_coordinates) + _x, _y = zip(*execution_time_coordinates, strict=True) x = np.array(_x) y = np.array(_y) xvalues = x + op_start_time diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py index f2b36b74..6ab4057f 100644 --- a/b_asic/scheduler.py +++ b/b_asic/scheduler.py @@ -164,7 +164,7 @@ class ALAPScheduler(Scheduler): # adjust the scheduling time if empty time slots have appeared in the start slack = min(schedule.start_times.values()) - for op_id in schedule.start_times.keys(): + for op_id in schedule.start_times: schedule.move_operation(op_id, -slack) schedule.set_schedule_time(schedule._schedule_time - slack) @@ -349,7 +349,7 @@ class ListScheduler(Scheduler): def _calculate_fan_outs(self) -> dict["GraphID", int]: return { op_id: len(self._sfg.find_by_id(op_id).output_signals) - for op_id in self._alap_start_times.keys() + for op_id in self._alap_start_times } def _calculate_memory_reads( @@ -379,13 +379,13 @@ class ListScheduler(Scheduler): if other_op_id != op._graph_id: if self._schedule._schedule_time is not None: start_time = start_time % self._schedule._schedule_time - - if time >= start_time: - if time < start_time + max( - self._cached_execution_times[other_op_id], 1 - ): - if isinstance(self._sfg.find_by_id(other_op_id), type(op)): - count += 1 + if ( + time >= start_time + and time + < start_time + max(self._cached_execution_times[other_op_id], 1) + and isinstance(self._sfg.find_by_id(other_op_id), type(op)) + ): + count += 1 return count def _op_satisfies_resource_constraints(self, op: "Operation") -> bool: @@ -446,7 +446,7 @@ class ListScheduler(Scheduler): tmp_used_reads = {} for i, op_input in enumerate(op.inputs): source_op = op_input.signals[0].source.operation - if isinstance(source_op, Delay) or isinstance(source_op, DontCare): + if isinstance(source_op, (Delay, DontCare)): continue if ( self._schedule.start_times[source_op.graph_id] @@ -477,7 +477,7 @@ class ListScheduler(Scheduler): source_port = op_input.signals[0].source source_op = source_port.operation - if isinstance(source_op, Delay) or isinstance(source_op, DontCare): + if isinstance(source_op, (Delay, DontCare)): continue if source_op.graph_id in self._remaining_ops: @@ -519,7 +519,7 @@ class ListScheduler(Scheduler): self._schedule = schedule self._sfg = schedule._sfg - for resource_type in self._max_resources.keys(): + for resource_type in self._max_resources: if not self._sfg.find_by_type_name(resource_type): raise ValueError( f"Provided max resource of type {resource_type} cannot be found in the provided SFG." @@ -528,7 +528,7 @@ class ListScheduler(Scheduler): differing_elems = [ resource for resource in self._sfg.get_used_type_names() - if resource not in self._max_resources.keys() + if resource not in self._max_resources and resource != Delay.type_name() and resource != DontCare.type_name() and resource != Sink.type_name() @@ -536,13 +536,13 @@ class ListScheduler(Scheduler): for type_name in differing_elems: self._max_resources[type_name] = 1 - for key in self._input_times.keys(): + for key in self._input_times: if self._sfg.find_by_id(key) is None: raise ValueError( f"Provided input time with GraphID {key} cannot be found in the provided SFG." ) - for key in self._output_delta_times.keys(): + for key in self._output_delta_times: if self._sfg.find_by_id(key) is None: raise ValueError( f"Provided output delta time with GraphID {key} cannot be found in the provided SFG." @@ -574,16 +574,19 @@ class ListScheduler(Scheduler): self._alap_op_laps = alap_scheduler.op_laps self._alap_schedule_time = alap_schedule._schedule_time self._schedule.start_times = {} - for key in self._schedule._laps.keys(): + for key in self._schedule._laps: self._schedule._laps[key] = 0 - if not self._schedule._cyclic and self._schedule._schedule_time: - if alap_schedule._schedule_time > self._schedule._schedule_time: - raise ValueError( - f"Provided scheduling time {schedule._schedule_time} cannot be reached, " - "try to enable the cyclic property or increase the time to at least " - f"{alap_schedule._schedule_time}." - ) + if ( + not self._schedule._cyclic + and self._schedule.schedule_time + and alap_schedule.schedule_time > self._schedule.schedule_time + ): + raise ValueError( + f"Provided scheduling time {schedule.schedule_time} cannot be reached, " + "try to enable the cyclic property or increase the time to at least " + f"{alap_schedule.schedule_time}." + ) self._remaining_resources = self._max_resources.copy() @@ -753,13 +756,13 @@ class ListScheduler(Scheduler): if ( not self._schedule._cyclic and self._schedule._schedule_time is not None + and new_time > self._schedule._schedule_time ): - if new_time > self._schedule._schedule_time: - raise ValueError( - f"Cannot place output {output.graph_id} at time {new_time} " - f"for scheduling time {self._schedule._schedule_time}. " - "Try to relax the scheduling time, change the output delta times or enable cyclic." - ) + raise ValueError( + f"Cannot place output {output.graph_id} at time {new_time} " + f"for scheduling time {self._schedule._schedule_time}. " + "Try to relax the scheduling time, change the output delta times or enable cyclic." + ) self._logger.debug( f" {output.graph_id} moved {min_slack} time steps backwards to new time {new_time}" ) @@ -818,20 +821,19 @@ class RecursiveListScheduler(ListScheduler): def _get_recursive_ops(self, loops: list[list["GraphID"]]) -> list["GraphID"]: recursive_ops = [] - seen = [] for loop in loops: for op_id in loop: - if op_id not in seen: - if not isinstance(self._sfg.find_by_id(op_id), Delay): - recursive_ops.append(op_id) - seen.append(op_id) + if op_id not in recursive_ops and not isinstance( + self._sfg.find_by_id(op_id), Delay + ): + recursive_ops.append(op_id) return recursive_ops def _recursive_op_satisfies_data_dependencies(self, op: "Operation") -> bool: - for input_port_index, op_input in enumerate(op.inputs): + for op_input in op.inputs: source_port = source_op = op_input.signals[0].source source_op = source_port.operation - if isinstance(source_op, Delay) or isinstance(source_op, DontCare): + if isinstance(source_op, (Delay, DontCare)): continue if ( source_op.graph_id in self._recursive_ops @@ -937,7 +939,7 @@ class RecursiveListScheduler(ListScheduler): for op_input in op.inputs: source_port = op_input.signals[0].source source_op = source_port.operation - if isinstance(source_op, Delay) or isinstance(source_op, DontCare): + if isinstance(source_op, (Delay, DontCare)): continue if source_op.graph_id in self._remaining_ops: return False diff --git a/b_asic/scheduler_gui/compile.py b/b_asic/scheduler_gui/compile.py index 922b6ecf..c88ec3e6 100644 --- a/b_asic/scheduler_gui/compile.py +++ b/b_asic/scheduler_gui/compile.py @@ -177,11 +177,7 @@ def compile_ui(*filenames: str) -> None: ) os_ = sys.platform - if os_.startswith("linux"): # Linux - cmd = f"{uic_} {arguments}" - subprocess.call(cmd.split()) - - elif os_.startswith("win32"): # Windows + if os_.startswith("linux") or os_.startswith("win32"): cmd = f"{uic_} {arguments}" subprocess.call(cmd.split()) @@ -190,7 +186,7 @@ def compile_ui(*filenames: str) -> None: log.error("macOS UI compiler not implemented") raise NotImplementedError - else: # other OS + else: log.error(f"{os_} UI compiler not supported") raise NotImplementedError diff --git a/b_asic/scheduler_gui/main_window.py b/b_asic/scheduler_gui/main_window.py index a39aacf4..63ae3ff6 100644 --- a/b_asic/scheduler_gui/main_window.py +++ b/b_asic/scheduler_gui/main_window.py @@ -976,7 +976,7 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): log.error("'Operator' not found in info table. It may have been renamed.") def _create_recent_file_actions_and_menus(self): - for i in range(self._max_recent_files): + for _ in range(self._max_recent_files): recent_file_action = QAction(self.menu_Recent_Schedule) recent_file_action.setVisible(False) recent_file_action.triggered.connect( @@ -1234,12 +1234,12 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): elif not LATENCY_COLOR_TYPE.changed and not self._color_changed_perType: self._color_per_type[type] = LATENCY_COLOR_TYPE.DEFAULT elif not LATENCY_COLOR_TYPE.changed and self._color_changed_perType: - if type in self.changed_operation_colors.keys(): + if type in self.changed_operation_colors: self._color_per_type[type] = self.changed_operation_colors[type] else: self._color_per_type[type] = LATENCY_COLOR_TYPE.DEFAULT else: - if type in self.changed_operation_colors.keys(): + if type in self.changed_operation_colors: self._color_per_type[type] = self.changed_operation_colors[type] else: self._color_per_type[type] = LATENCY_COLOR_TYPE.current_color @@ -1287,10 +1287,7 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): # If a valid color is selected, update the current color and settings if color.isValid(): color_type.current_color = color - # colorbutton.set_color(color) - color_type.changed = ( - False if color_type.current_color == color_type.DEFAULT else True - ) + color_type.changed = color_type.current_color != color_type.DEFAULT settings.setValue(f"scheduler/preferences/{color_type.name}", color.name()) settings.sync() @@ -1312,10 +1309,7 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): boldbutton : ColorButton The button displaying the bold state to be matched with the chosen font. """ - if FONT.changed: - current_font = FONT.current_font - else: - current_font = FONT.DEFAULT + current_font = FONT.current_font if FONT.changed else FONT.DEFAULT (ok, font) = QFontDialog.getFont(current_font, self) if ok: @@ -1330,15 +1324,11 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): def update_font(self): """Update font preferences based on current Font settings""" settings = QSettings() - FONT.changed = ( - False - if ( - FONT.current_font == FONT.DEFAULT - and FONT.size == int(FONT.DEFAULT.pointSizeF()) - and FONT.italic == FONT.DEFAULT.italic() - and FONT.bold == FONT.DEFAULT.bold() - ) - else True + FONT.changed = not ( + FONT.current_font == FONT.DEFAULT + and FONT.size == int(FONT.DEFAULT.pointSizeF()) + and FONT.italic == FONT.DEFAULT.italic() + and FONT.bold == FONT.DEFAULT.bold() ) settings.setValue("scheduler/preferences/font", FONT.current_font.toString()) settings.setValue("scheduler/preferences/fontSize", FONT.size) @@ -1411,7 +1401,7 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): size The font size to be set. """ - FONT.size = int(size) if (not size == "") else 6 + FONT.size = int(size) if size != "" else 6 FONT.current_font.setPointSizeF(FONT.size) self.update_font() diff --git a/b_asic/scheduler_gui/scheduler_event.py b/b_asic/scheduler_gui/scheduler_event.py index 872b4511..b191fa50 100644 --- a/b_asic/scheduler_gui/scheduler_event.py +++ b/b_asic/scheduler_gui/scheduler_event.py @@ -210,14 +210,16 @@ class SchedulerEvent: if pos_x < 0: pos_x += self._schedule.schedule_time redraw = True - if pos_x > self._schedule.schedule_time: + if ( + pos_x > self._schedule.schedule_time # If zero execution time, keep operation at the edge - if ( + and ( pos_x > self._schedule.schedule_time + 1 or item.operation.execution_time - ): - pos_x = pos_x % self._schedule.schedule_time - redraw = True + ) + ): + pos_x = pos_x % self._schedule.schedule_time + redraw = True pos_y = self._schedule.get_y_location(item.operation.graph_id) # Check move in y-direction if pos_y != self._old_op_position: diff --git a/b_asic/scheduler_gui/scheduler_item.py b/b_asic/scheduler_gui/scheduler_item.py index 1bae4d31..2097e295 100644 --- a/b_asic/scheduler_gui/scheduler_item.py +++ b/b_asic/scheduler_gui/scheduler_item.py @@ -149,10 +149,7 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): def _color_change(self, color: QColor, name: str) -> None: """Change inactive color of operation item *.""" for op in self.components: - if name == "all operations": - op._set_background(color) - op._inactive_color = color - elif name == op.operation.type_name(): + if name in ("all operations", op.operation.type_name()): op._set_background(color) op._inactive_color = color @@ -327,7 +324,7 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): def _redraw_from_start(self) -> None: self.schedule.reset_y_locations() self.schedule.sort_y_locations_on_start_times() - for graph_id in self.schedule.start_times.keys(): + for graph_id in self.schedule.start_times: self._set_position(graph_id) self._redraw_all_lines() self._update_axes() @@ -355,7 +352,7 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): def _make_graph(self) -> None: """Make a new graph out of the stored attributes.""" # build components - for graph_id in self.schedule.start_times.keys(): + for graph_id in self.schedule.start_times: operation = cast(Operation, self.schedule._sfg.find_by_id(graph_id)) component = OperationItem(operation, height=OPERATION_HEIGHT, parent=self) self._operation_items[graph_id] = component diff --git a/b_asic/sfg_generators.py b/b_asic/sfg_generators.py index a6d4db39..faf79d2b 100644 --- a/b_asic/sfg_generators.py +++ b/b_asic/sfg_generators.py @@ -415,7 +415,7 @@ def radix_2_dif_fft(points: int) -> SFG: raise ValueError("Points must be a power of two.") inputs = [] - for i in range(points): + for _ in range(points): inputs.append(Input()) ports = inputs @@ -430,7 +430,7 @@ def radix_2_dif_fft(points: int) -> SFG: ports = _get_bit_reversed_ports(ports) outputs = [] - for i, port in enumerate(ports): + for port in ports: outputs.append(Output(port)) return SFG(inputs=inputs, outputs=outputs) diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index dace2bca..bf07f04a 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -287,7 +287,9 @@ class SFG(AbstractOperation): new_signal = cast(Signal, self._original_components_to_new[signal]) if new_signal.source in output_sources: - warnings.warn("Two signals connected to the same output port") + warnings.warn( + "Two signals connected to the same output port", stacklevel=2 + ) output_sources.append(new_signal.source) if new_signal.source is None: @@ -377,6 +379,7 @@ class SFG(AbstractOperation): if quantize else input_values ), + strict=True, ): op.value = arg @@ -434,7 +437,9 @@ class SFG(AbstractOperation): return False # For each input_signal, connect it to the corresponding operation - for input_port, input_operation in zip(self.inputs, self.input_operations): + for input_port, input_operation in zip( + self.inputs, self.input_operations, strict=True + ): destination = input_operation.output(0).signals[0].destination if destination is None: raise ValueError("Missing destination in signal.") @@ -448,7 +453,9 @@ class SFG(AbstractOperation): other_destination.add_signal(Signal(destination.signals[0].source)) input_operation.output(0).clear() # For each output_signal, connect it to the corresponding operation - for output_port, output_operation in zip(self.outputs, self.output_operations): + for output_port, output_operation in zip( + self.outputs, self.output_operations, strict=True + ): src = output_operation.input(0).signals[0].source if src is None: raise ValueError("Missing source in signal.") @@ -505,10 +512,9 @@ class SFG(AbstractOperation): visited: set[Operation] = {output_op} while queue: op = queue.popleft() - if isinstance(op, Input): - if op in sfg_input_operations_to_indexes: - input_indexes_required.append(sfg_input_operations_to_indexes[op]) - del sfg_input_operations_to_indexes[op] + if isinstance(op, Input) and op in sfg_input_operations_to_indexes: + input_indexes_required.append(sfg_input_operations_to_indexes[op]) + del sfg_input_operations_to_indexes[op] for input_port in op.inputs: for signal in input_port.signals: @@ -649,7 +655,9 @@ class SFG(AbstractOperation): if component_copy.type_name() == 'out': sfg_copy._output_operations.remove(component_copy) - warnings.warn(f"Output port {component_copy.graph_id} has been removed") + warnings.warn( + f"Output port {component_copy.graph_id} has been removed", stacklevel=2 + ) if component.type_name() == 'out': sfg_copy._output_operations.append(component) @@ -1007,7 +1015,7 @@ class SFG(AbstractOperation): ) # Creates edges for each output port and creates nodes for each operation # and edges for them as well - for i, ports in enumerate(p_list): + for ports in p_list: for port in ports: source_label = port.operation.graph_id node_node = port.name @@ -1774,7 +1782,7 @@ class SFG(AbstractOperation): op_and_latency = {} for op in self.operations: for loop in loops: - for element in loop: + for _ in loop: if op.type_name() not in op_and_latency: op_and_latency[op.type_name()] = op.latency t_l_values = [] @@ -1821,14 +1829,15 @@ class SFG(AbstractOperation): while queue: op = queue.popleft() for output_port in op.outputs: - if not (isinstance(op, Input) or isinstance(op, Output)): + if not isinstance(op, (Input, Output)): dict_of_sfg[op.graph_id] = [] for signal in output_port.signals: if signal.destination is not None: new_op = signal.destination.operation - if not (isinstance(op, Input) or isinstance(op, Output)): - if not isinstance(new_op, Output): - dict_of_sfg[op.graph_id] += [new_op.graph_id] + if not isinstance(op, (Input, Output)) and not isinstance( + new_op, Output + ): + dict_of_sfg[op.graph_id] += [new_op.graph_id] if new_op not in visited: queue.append(new_op) visited.add(new_op) @@ -2036,7 +2045,9 @@ class SFG(AbstractOperation): mat_content[row, column] += temp_value return matrix_answer, mat_content, matrix_in - def find_all_paths(self, graph: dict, start: str, end: str, path=[]) -> list: + def find_all_paths( + self, graph: dict, start: str, end: str, path: list | None = None + ) -> list: """ Returns all paths in graph from node start to node end @@ -2053,6 +2064,9 @@ class SFG(AbstractOperation): ------- The state-space representation of the SFG. """ + if path is None: + path = [] + path = path + [start] if start == end: return [path] @@ -2137,7 +2151,9 @@ class SFG(AbstractOperation): # For each copy of the SFG, create new input operations for every "original" # input operation and connect them to begin creating the unfolded SFG for i in range(factor): - for port, operation in zip(sfgs[i].inputs, sfgs[i].input_operations): + for port, operation in zip( + sfgs[i].inputs, sfgs[i].input_operations, strict=True + ): if not operation.name.startswith("input_t"): i = Input() new_inputs.append(i) @@ -2154,7 +2170,9 @@ class SFG(AbstractOperation): new_outputs = [] delay_placements = {} for i in range(factor): - for port, operation in zip(sfgs[i].outputs, sfgs[i].output_operations): + for port, operation in zip( + sfgs[i].outputs, sfgs[i].output_operations, strict=True + ): if not operation.name.startswith("output_t"): new_outputs.append(Output(port)) else: diff --git a/b_asic/utils.py b/b_asic/utils.py index 654d86a9..f2a8b630 100644 --- a/b_asic/utils.py +++ b/b_asic/utils.py @@ -26,7 +26,7 @@ def interleave(*args) -> list[Num]: ... interleave(a, b, c) [1, 3, -1, 2, 4, 0] """ - return [val for tup in zip(*args) for val in tup] + return [val for tup in zip(*args, strict=True) for val in tup] def downsample(a: Sequence[Num], factor: int, phase: int = 0) -> list[Num]: diff --git a/pyproject.toml b/pyproject.toml index 58dc2c51..c1ab97e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,7 +97,11 @@ ignore_missing_imports = true precision = 2 [tool.ruff] -lint.ignore = ["F403"] +exclude = ["examples"] + +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "SIM", "B"] +ignore = ["F403", "B008", "B021", "B006"] [tool.typos] default.extend-identifiers = { ba = "ba", addd0 = "addd0", inout = "inout", ArChItEctUrE = "ArChItEctUrE" } diff --git a/test/unit/test_list_schedulers.py b/test/unit/test_list_schedulers.py index beba875f..7821a2bb 100644 --- a/test/unit/test_list_schedulers.py +++ b/test/unit/test_list_schedulers.py @@ -1853,7 +1853,7 @@ class TestRecursiveListScheduler: _validate_recreated_sfg_filter(sfg, schedule) def test_large_direct_form_2_iir(self): - N = 10 + N = 8 Wc = 0.2 b, a = signal.butter(N, Wc, btype="lowpass", output="ba") sfg = direct_form_2_iir(b, a) diff --git a/test/unit/test_sfg.py b/test/unit/test_sfg.py index 090a11dd..7dd05ab0 100644 --- a/test/unit/test_sfg.py +++ b/test/unit/test_sfg.py @@ -308,16 +308,16 @@ class TestReplaceOperation: component_id = "add0" sfg = sfg.replace_operation(Multiplication(name="Multi"), graph_id=component_id) - assert component_id not in sfg._components_by_id.keys() - assert "Multi" in sfg._components_by_name.keys() + assert component_id not in sfg._components_by_id + assert "Multi" in sfg._components_by_name def test_replace_addition_large_tree(self, large_operation_tree): sfg = SFG(outputs=[Output(large_operation_tree)]) component_id = "add2" sfg = sfg.replace_operation(Multiplication(name="Multi"), graph_id=component_id) - assert "Multi" in sfg._components_by_name.keys() - assert component_id not in sfg._components_by_id.keys() + assert "Multi" in sfg._components_by_name + assert component_id not in sfg._components_by_id def test_replace_no_input_component(self, operation_tree): sfg = SFG(outputs=[Output(operation_tree)]) @@ -1514,7 +1514,7 @@ class TestUnfold: # Ensure that we aren't missing any keys, or have any extras assert count1.keys() == count2.keys() - for k in count1.keys(): + for k in count1: assert count1[k] * multiple == count2[k] # This is horrifying, but I can't figure out a way to run the test on multiple @@ -1818,8 +1818,8 @@ class TestResourceLowerBound: assert sfg.resource_lower_bound("cmul", 1000) == 0 def test_type_not_in_sfg(self, sfg_simple_accumulator): - sfg_simple_accumulator.resource_lower_bound("bfly", 2) == 0 - sfg_simple_accumulator.resource_lower_bound("bfly", 1000) == 0 + assert sfg_simple_accumulator.resource_lower_bound("bfly", 2) == 0 + assert sfg_simple_accumulator.resource_lower_bound("bfly", 1000) == 0 def test_negative_schedule_time(self, precedence_sfg_delays): precedence_sfg_delays.set_latency_of_type("add", 2) diff --git a/test/unit/test_signal.py b/test/unit/test_signal.py index ca2c57f6..469715e6 100644 --- a/test/unit/test_signal.py +++ b/test/unit/test_signal.py @@ -160,4 +160,4 @@ def test_signal_errors(): signal = Signal() with pytest.raises(ValueError, match="Signal source not set"): - signal.is_constant + _ = signal.is_constant -- GitLab