Skip to content
Snippets Groups Projects
Commit 5a905f37 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Fix some typing

parent 722749dd
No related branches found
No related tags found
1 merge request!225Fix some typing
Pipeline #90267 passed
......@@ -120,7 +120,7 @@ class ProcessCollection:
return self._collection
def __len__(self):
return len(self.__collection__)
return len(self._collection)
def add_process(self, process: Process):
"""
......
......@@ -79,8 +79,8 @@ class Schedule:
schedule_time: Optional[int] = None,
cyclic: bool = False,
scheduling_algorithm: str = "ASAP",
start_times: Dict[GraphID, int] = None,
laps: Dict[GraphID, int] = None,
start_times: Optional[Dict[GraphID, int]] = None,
laps: Optional[Dict[GraphID, int]] = None,
):
"""Construct a Schedule from an SFG."""
self._original_sfg = sfg() # Make a copy
......@@ -92,6 +92,10 @@ class Schedule:
if scheduling_algorithm == "ASAP":
self._schedule_asap()
elif scheduling_algorithm == "provided":
if start_times is None:
raise ValueError("Must provide start_times when using 'provided'")
if laps is None:
raise ValueError("Must provide laps when using 'provided'")
self._start_times = start_times
self._laps.update(laps)
self._remove_delays_no_laps()
......@@ -403,10 +407,10 @@ class Schedule:
"""
if insert:
for gid, y_location in self._y_locations.items():
if y_location >= new_y:
self._y_locations[gid] += 1
self._y_locations[graph_id] = new_y
for gid in self._y_locations:
if self.get_y_location(gid) >= new_y:
self.set_y_location(gid, self.get_y_location(gid) + 1)
self.set_y_location(graph_id, new_y)
used_locations = {*self._y_locations.values()}
possible_locations = set(range(max(used_locations) + 1))
if not possible_locations - used_locations:
......@@ -889,7 +893,7 @@ class Schedule:
def _reset_y_locations(self) -> None:
"""Reset all the y-locations in the schedule to None"""
self._y_locations = self._y_locations = defaultdict(lambda: None)
self._y_locations = defaultdict(lambda: None)
def plot_in_axes(self, ax: Axes, operation_gap: Optional[float] = None) -> None:
"""
......
import pickle
import matplotlib.pyplot as plt
import networkx as nx
import pytest
from b_asic.process import Process
from b_asic.research.interleaver import (
generate_matrix_transposer,
generate_random_interleaver,
)
from b_asic.resources import ProcessCollection, draw_exclusion_graph_coloring
from b_asic.resources import ProcessCollection
class TestProcessCollectionPlainMemoryVariable:
......@@ -44,3 +42,6 @@ class TestProcessCollectionPlainMemoryVariable:
assert len(collection.split_ports(read_ports=1, write_ports=1)) == 1
if any(var.execution_time for var in collection.collection):
assert len(collection.split_ports(total_ports=1)) == 2
def test_len_process_collection(self, simple_collection: ProcessCollection):
assert len(simple_collection) == 7
......@@ -495,6 +495,8 @@ class TestProcesses:
def test__get_memory_variables_list(self, secondorder_iir_schedule):
mvl = secondorder_iir_schedule._get_memory_variables_list()
assert len(mvl) == 12
pc = secondorder_iir_schedule.get_memory_variables()
assert len(pc) == 12
class TestFigureGeneration:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment