Compare revisions
Commits on Source (3)
......@@ -25,9 +25,7 @@ _T = TypeVar('_T')
def _sorted_nicely(to_be_sorted: Iterable[_T]) -> List[_T]:
"""Sort the given iterable in the way that humans expect."""
convert = lambda text: int(text) if text.isdigit() else text
alphanum_key = lambda key: [
convert(c) for c in re.split('([0-9]+)', str(key))
]
alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', str(key))]
return sorted(to_be_sorted, key=alphanum_key)
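For illustration, a minimal self-contained sketch reproducing the helper above: plain sorted() orders strings lexicographically, while the alphanum key splits out digit runs and compares them as integers.

import re

def _sorted_nicely(to_be_sorted):
    convert = lambda text: int(text) if text.isdigit() else text
    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', str(key))]
    return sorted(to_be_sorted, key=alphanum_key)

names = ["t10", "t2", "t1"]
print(sorted(names))          # ['t1', 't10', 't2'] (lexicographic)
print(_sorted_nicely(names))  # ['t1', 't2', 't10'] (human order)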
......@@ -35,9 +33,7 @@ def draw_exclusion_graph_coloring(
exclusion_graph: nx.Graph,
color_dict: Dict[Process, int],
ax: Optional[Axes] = None,
color_list: Optional[
Union[List[str], List[Tuple[float, float, float]]]
] = None,
color_list: Optional[Union[List[str], List[Tuple[float, float, float]]]] = None,
):
"""
Use matplotlib.pyplot and networkx to draw a colored exclusion graph from the memory assignment
......@@ -119,6 +115,13 @@ class ProcessCollection:
self._schedule_time = schedule_time
self._cyclic = cyclic
@property
def collection(self):
return self._collection
def __len__(self):
return len(self._collection)
def add_process(self, process: Process):
"""
Add a new process to this process collection.
......@@ -174,9 +177,7 @@ class ProcessCollection:
# Lifetime chart left and right padding
PAD_L, PAD_R = 0.05, 0.05
max_execution_time = max(
process.execution_time for process in self._collection
)
max_execution_time = max(process.execution_time for process in self._collection)
if max_execution_time > self._schedule_time:
# Schedule time needs to be greater than or equal to the maximum process lifetime
raise KeyError(
......@@ -187,10 +188,12 @@ class ProcessCollection:
# Generate the life-time chart
for i, process in enumerate(_sorted_nicely(self._collection)):
bar_start = process.start_time % self._schedule_time
bar_end = (
    process.start_time + process.execution_time
) % self._schedule_time
bar_end = self._schedule_time if bar_end == 0 else bar_end
bar_end = process.start_time + process.execution_time
bar_end = (
    bar_end
    if bar_end == self._schedule_time
    else bar_end % self._schedule_time
)
if show_markers:
_ax.scatter(
x=bar_start,
......@@ -240,16 +243,84 @@ class ProcessCollection:
_ax.set_ylim(0.25, len(self._collection) + 0.75)
return _ax
def create_exclusion_graph_from_overlap(
self, add_name: bool = True
def create_exclusion_graph_from_ports(
self,
read_ports: Optional[int] = None,
write_ports: Optional[int] = None,
total_ports: Optional[int] = None,
) -> nx.Graph:
"""
Generate exclusion graph based on processes overlapping in time
Create an exclusion graph from a ProcessCollection based on a number of read/write ports
Parameters
----------
add_name : bool, default: True
Add name of all processes as a node attribute in the exclusion graph.
read_ports : int
The number of read ports used when splitting process collection based on memory variable access.
write_ports : int
The number of write ports used when splitting process collection based on memory variable access.
total_ports : int
The total number of ports used when splitting process collection based on memory variable access.
Returns
-------
nx.Graph
"""
if total_ports is None:
if read_ports is None or write_ports is None:
raise ValueError(
"If total_ports is unset, both read_ports and write_ports"
" must be provided."
)
else:
total_ports = read_ports + write_ports
else:
read_ports = total_ports if read_ports is None else read_ports
write_ports = total_ports if write_ports is None else write_ports
# Guard for proper read/write port settings
if read_ports != 1 or write_ports != 1:
raise ValueError(
"Splitting with read and write ports not equal to one with the"
" graph coloring heuristic does not make sense."
)
if total_ports not in (1, 2):
raise ValueError(
"Total ports should be either 1 (non-concurrent reads/writes)"
" or 2 (concurrent read/writes) for graph coloring heuristic."
)
# Create new exclusion graph. Nodes are Processes
exclusion_graph = nx.Graph()
exclusion_graph.add_nodes_from(self._collection)
for node1 in exclusion_graph:
for node2 in exclusion_graph:
if node1 == node2:
continue
else:
node1_stop_time = node1.start_time + node1.execution_time
node2_stop_time = node2.start_time + node2.execution_time
if total_ports == 1:
# Single-port assignment
if node1.start_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
elif node1.start_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
else:
# Dual-port assignment
if node1.start_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
return exclusion_graph
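A minimal downstream sketch, assuming `collection` is a ProcessCollection of memory variables: processes joined by an edge in this graph cannot share a memory, so a greedy coloring yields one memory bank per color.

import networkx as nx

graph = collection.create_exclusion_graph_from_ports(
    read_ports=1, write_ports=1, total_ports=2
)
coloring = nx.coloring.greedy_color(graph, strategy="saturation_largest_first")
num_memories = max(coloring.values()) + 1  # one memory bank per color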
def create_exclusion_graph_from_execution_time(self) -> nx.Graph:
"""
Generate exclusion graph based on processes overlapping in time
Returns
-------
......@@ -279,7 +350,47 @@ class ProcessCollection:
exclusion_graph.add_edge(process1, process2)
return exclusion_graph
def split(
def split_execution_time(
self, heuristic: str = "graph_color", coloring_strategy: str = "DSATUR"
) -> Set["ProcessCollection"]:
"""
Split a ProcessCollection based on overlapping execution time.
Parameters
----------
heuristic : str, default: 'graph_color'
The heuristic used when splitting based on execution times.
One of: 'graph_color', 'left_edge'.
coloring_strategy: str, default: 'DSATUR'
Node ordering strategy passed to nx.coloring.greedy_color().
This parameter is only considered if heuristic is set to 'graph_color'.
One of
* `'largest_first'`
* `'random_sequential'`
* `'smallest_last'`
* `'independent_set'`
* `'connected_sequential_bfs'`
* `'connected_sequential_dfs'`
* `'connected_sequential'` (alias for the previous strategy)
* `'saturation_largest_first'`
* `'DSATUR'` (alias for the saturation_largest_first strategy)
Returns
-------
A set of new ProcessCollection objects resulting from the split.
"""
if heuristic == "graph_color":
exclusion_graph = self.create_exclusion_graph_from_execution_time()
coloring = nx.coloring.greedy_color(
exclusion_graph, strategy=coloring_strategy
)
return self._split_from_graph_coloring(coloring)
elif heuristic == "left_edge":
raise NotImplementedError()
else:
raise ValueError(f"Invalid heuristic '{heuristic}'")
def split_ports(
self,
heuristic: str = "graph_color",
read_ports: Optional[int] = None,
......@@ -309,86 +420,80 @@ class ProcessCollection:
"""
if total_ports is None:
if read_ports is None or write_ports is None:
raise ValueError("inteligent quote")
raise ValueError(
"If total_ports is unset, both read_ports and write_ports"
" must be provided."
)
else:
total_ports = read_ports + write_ports
else:
read_ports = total_ports if read_ports is None else read_ports
write_ports = total_ports if write_ports is None else write_ports
if heuristic == "graph_color":
return self._split_graph_color(
read_ports, write_ports, total_ports
)
return self._split_ports_graph_color(read_ports, write_ports, total_ports)
else:
raise ValueError("Invalid heuristic provided")
raise ValueError("Invalid heuristic provided.")
def _split_graph_color(
self, read_ports: int, write_ports: int, total_ports: int
def _split_ports_graph_color(
self,
read_ports: int,
write_ports: int,
total_ports: int,
coloring_strategy: str = "DSATUR",
) -> Set["ProcessCollection"]:
"""
Parameters
----------
read_ports : int, optional
read_ports : int
The number of read ports used when splitting process collection based on memory variable access.
write_ports : int, optional
write_ports : int
The number of write ports used when splitting process collection based on memory variable access.
total_ports : int, optional
total_ports : int
The total number of ports used when splitting process collection based on memory variable access.
coloring_strategy: str, default: 'DSATUR'
Node ordering strategy passed to nx.coloring.greedy_color()
One of
* `'largest_first'`
* `'random_sequential'`
* `'smallest_last'`
* `'independent_set'`
* `'connected_sequential_bfs'`
* `'connected_sequential_dfs'`
* `'connected_sequential'` (alias for the previous strategy)
* `'saturation_largest_first'`
* `'DSATUR'` (alias for the saturation_largest_first strategy)
"""
if read_ports != 1 or write_ports != 1:
raise ValueError(
"Splitting with read and write ports not equal to one with the"
" graph coloring heuristic does not make sense."
)
if total_ports not in (1, 2):
raise ValueError(
"Total ports should be either 1 (non-concurrent reads/writes)"
" or 2 (concurrent read/writes) for graph coloring heuristic."
)
# Create new exclusion graph. Nodes are Processes
exclusion_graph = nx.Graph()
exclusion_graph.add_nodes_from(self._collection)
exclusion_graph = self.create_exclusion_graph_from_ports(
read_ports, write_ports, total_ports
)
# Add exclusions (arcs) between processes in the exclusion graph
for node1 in exclusion_graph:
for node2 in exclusion_graph:
if node1 == node2:
continue
else:
node1_stop_time = node1.start_time + node1.execution_time
node2_stop_time = node2.start_time + node2.execution_time
if total_ports == 1:
# Single-port assignment
if node1.start_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
elif node1.start_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
else:
# Dual-port assignment
if node1.start_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
# Perform assignment from coloring and return result
coloring = nx.coloring.greedy_color(exclusion_graph, strategy=coloring_strategy)
return self._split_from_graph_coloring(coloring)
# Perform assignment
coloring = nx.coloring.greedy_color(exclusion_graph)
draw_exclusion_graph_coloring(exclusion_graph, coloring)
# process_collection_list = [ProcessCollection()]*(max(coloring.values()) + 1)
process_collection_set_list = [
set() for _ in range(max(coloring.values()) + 1)
]
def _split_from_graph_coloring(
self,
coloring: Dict[Process, int],
) -> Set["ProcessCollection"]:
"""
Split :class:`Process` objects into a set of :class:`ProcessCollection` objects based on a provided graph coloring.
The resulting :class:`ProcessCollection` objects have the same schedule time and cyclic property as self.
Parameters
----------
coloring : Dict[Process, int]
Process->int (color) mappings
Returns
-------
A set of new ProcessCollections.
"""
process_collection_set_list = [set() for _ in range(max(coloring.values()) + 1)]
for process, color in coloring.items():
process_collection_set_list[color].add(process)
return {
ProcessCollection(
process_collection_set, self._schedule_time, self._cyclic
)
ProcessCollection(process_collection_set, self._schedule_time, self._cyclic)
for process_collection_set in process_collection_set_list
}
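A standalone illustration of the coloring-to-partition step above, using hypothetical string stand-ins for Process objects: every process mapped to the same color lands in the same set.

coloring = {"p0": 0, "p1": 1, "p2": 0}
partitions = [set() for _ in range(max(coloring.values()) + 1)]
for process, color in coloring.items():
    partitions[color].add(process)
print(partitions)  # [{'p0', 'p2'}, {'p1'}]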
......@@ -403,3 +508,9 @@ class ProcessCollection:
fig.savefig(f, format="svg")
return f.getvalue()
def __repr__(self):
return (
f"ProcessCollection({self._collection}, {self._schedule_time},"
f" {self._cyclic})"
)
......@@ -34,7 +34,7 @@ from b_asic.port import InputPort, OutputPort
from b_asic.process import MemoryVariable, Process
from b_asic.resources import ProcessCollection
from b_asic.signal_flow_graph import SFG
from b_asic.special_operations import Delay, Output
from b_asic.special_operations import Delay, Input, Output
# Need RGB from 0 to 1
_EXECUTION_TIME_COLOR = tuple(c / 255 for c in EXECUTION_TIME_COLOR)
......@@ -91,9 +91,7 @@ class Schedule:
if schedule_time is None:
self._schedule_time = max_end_time
elif schedule_time < max_end_time:
raise ValueError(
f"Too short schedule time. Minimum is {max_end_time}."
)
raise ValueError(f"Too short schedule time. Minimum is {max_end_time}.")
else:
self._schedule_time = schedule_time
......@@ -102,9 +100,7 @@ class Schedule:
Return the start time of the operation specified by *graph_id*.
"""
if graph_id not in self._start_times:
raise ValueError(
f"No operation with graph_id {graph_id} in schedule"
)
raise ValueError(f"No operation with graph_id {graph_id} in schedule")
return self._start_times[graph_id]
def get_max_end_time(self) -> int:
......@@ -139,9 +135,7 @@ class Schedule:
slacks
"""
if graph_id not in self._start_times:
raise ValueError(
f"No operation with graph_id {graph_id} in schedule"
)
raise ValueError(f"No operation with graph_id {graph_id} in schedule")
slack = sys.maxsize
output_slacks = self._forward_slacks(graph_id)
# Make more pythonic
......@@ -194,9 +188,7 @@ class Schedule:
slacks
"""
if graph_id not in self._start_times:
raise ValueError(
f"No operation with graph_id {graph_id} in schedule"
)
raise ValueError(f"No operation with graph_id {graph_id} in schedule")
slack = sys.maxsize
input_slacks = self._backward_slacks(graph_id)
# Make more pythonic
......@@ -205,9 +197,7 @@ class Schedule:
slack = min(slack, signal_slack)
return slack
def _backward_slacks(
self, graph_id: GraphID
) -> Dict[InputPort, Dict[Signal, int]]:
def _backward_slacks(self, graph_id: GraphID) -> Dict[InputPort, Dict[Signal, int]]:
ret = {}
start_time = self._start_times[graph_id]
op = cast(Operation, self._sfg.find_by_id(graph_id))
......@@ -250,9 +240,7 @@ class Schedule:
"""
if graph_id not in self._start_times:
raise ValueError(
f"No operation with graph_id {graph_id} in schedule"
)
raise ValueError(f"No operation with graph_id {graph_id} in schedule")
return self.backward_slack(graph_id), self.forward_slack(graph_id)
def print_slacks(self) -> None:
......@@ -309,13 +297,11 @@ class Schedule:
factor : int
The time resolution increment.
"""
self._start_times = {
k: factor * v for k, v in self._start_times.items()
}
self._start_times = {k: factor * v for k, v in self._start_times.items()}
for graph_id in self._start_times:
cast(
Operation, self._sfg.find_by_id(graph_id)
)._increase_time_resolution(factor)
cast(Operation, self._sfg.find_by_id(graph_id))._increase_time_resolution(
factor
)
self._schedule_time *= factor
return self
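A brief sketch of the resolution round-trip, assuming `schedule` is a Schedule instance: increasing multiplies every start time and the schedule time by the factor; decreasing divides them again and is only permitted when all times are divisible by the factor.

schedule.increase_time_resolution(3)  # start times and schedule_time * 3
schedule.decrease_time_resolution(3)  # back to the original resolution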
......@@ -366,13 +352,11 @@ class Schedule:
f"Not possible to decrease resolution with {factor}. Possible"
f" values are {possible_values}"
)
self._start_times = {
k: v // factor for k, v in self._start_times.items()
}
self._start_times = {k: v // factor for k, v in self._start_times.items()}
for graph_id in self._start_times:
cast(
Operation, self._sfg.find_by_id(graph_id)
)._decrease_time_resolution(factor)
cast(Operation, self._sfg.find_by_id(graph_id))._decrease_time_resolution(
factor
)
self._schedule_time = self._schedule_time // factor
return self
......@@ -388,9 +372,7 @@ class Schedule:
The time to move. If positive, move forward; if negative, move backward.
"""
if graph_id not in self._start_times:
raise ValueError(
f"No operation with graph_id {graph_id} in schedule"
)
raise ValueError(f"No operation with graph_id {graph_id} in schedule")
(backward_slack, forward_slack) = self.slacks(graph_id)
if not -backward_slack <= time <= forward_slack:
......@@ -413,15 +395,25 @@ class Schedule:
tmp_prev_available = tmp_usage - new_slack
prev_available = tmp_prev_available % self._schedule_time
laps = new_slack // self._schedule_time
source_op = signal.source.operation
if new_usage < prev_available:
print("Incrementing input laps 1")
laps += 1
if prev_available == 0 and new_usage == 0:
if (
prev_available == 0
and new_usage == 0
and (
tmp_prev_available > 0
or tmp_prev_available == 0
and not isinstance(source_op, Input)
)
):
print("Incrementing input laps 2")
laps += 1
print(
[
"Input",
signal.source.operation,
time,
tmp_start,
signal_slack,
......@@ -476,12 +468,8 @@ class Schedule:
while delay_list:
delay_op = cast(Delay, delay_list[0])
delay_input_id = delay_op.input(0).signals[0].graph_id
delay_output_ids = [
sig.graph_id for sig in delay_op.output(0).signals
]
self._sfg = cast(
SFG, self._sfg.remove_operation(delay_op.graph_id)
)
delay_output_ids = [sig.graph_id for sig in delay_op.output(0).signals]
self._sfg = cast(SFG, self._sfg.remove_operation(delay_op.graph_id))
for output_id in delay_output_ids:
self._laps[output_id] += 1 + self._laps[delay_input_id]
del self._laps[delay_input_id]
......@@ -520,21 +508,16 @@ class Schedule:
for inport in op.inputs:
if len(inport.signals) != 1:
raise ValueError(
"Error in scheduling, dangling input port"
" detected."
"Error in scheduling, dangling input port detected."
)
if inport.signals[0].source is None:
raise ValueError(
"Error in scheduling, signal with no source"
" detected."
"Error in scheduling, signal with no source detected."
)
source_port = inport.signals[0].source
source_end_time = None
if (
source_port.operation.graph_id
in non_schedulable_ops
):
if source_port.operation.graph_id in non_schedulable_ops:
source_end_time = 0
else:
source_op_time = self._start_times[
......@@ -559,12 +542,8 @@ class Schedule:
f" {inport.operation.graph_id} has no"
" latency-offset."
)
op_start_time_from_in = (
source_end_time - inport.latency_offset
)
op_start_time = max(
op_start_time, op_start_time_from_in
)
op_start_time_from_in = source_end_time - inport.latency_offset
op_start_time = max(op_start_time, op_start_time_from_in)
self._start_times[op.graph_id] = op_start_time
for output in self._sfg.find_by_type_name(Output.type_name()):
......@@ -625,17 +604,13 @@ class Schedule:
y_location = self._y_locations[graph_id]
if y_location is None:
# Assign the lowest row number not yet in use
used = set(
loc for loc in self._y_locations.values() if loc is not None
)
used = set(loc for loc in self._y_locations.values() if loc is not None)
possible = set(range(len(self._start_times))) - used
y_location = min(possible)
self._y_locations[graph_id] = y_location
return operation_gap + y_location * (operation_height + operation_gap)
def _plot_schedule(
self, ax: Axes, operation_gap: Optional[float] = None
) -> None:
def _plot_schedule(self, ax: Axes, operation_gap: Optional[float] = None) -> None:
"""Draw the schedule."""
line_cache = []
......@@ -722,9 +697,7 @@ class Schedule:
)
ax.add_patch(pp)
def _draw_offset_arrow(
start, end, start_offset, end_offset, name="", laps=0
):
def _draw_offset_arrow(start, end, start_offset, end_offset, name="", laps=0):
"""Draw an arrow from *start* to *end*, but with an offset."""
_draw_arrow(
[start[0] + start_offset[0], start[1] + start_offset[1]],
......@@ -761,24 +734,18 @@ class Schedule:
linewidth=3,
)
ytickpositions.append(y_pos + 0.5)
yticklabels.append(
cast(Operation, self._sfg.find_by_id(graph_id)).name
)
yticklabels.append(cast(Operation, self._sfg.find_by_id(graph_id)).name)
for graph_id, op_start_time in self._start_times.items():
op = cast(Operation, self._sfg.find_by_id(graph_id))
out_coordinates = op.get_output_coordinates()
source_y_pos = self._get_y_position(
graph_id, operation_gap=operation_gap
)
source_y_pos = self._get_y_position(graph_id, operation_gap=operation_gap)
for output_port in op.outputs:
for output_signal in output_port.signals:
destination = cast(InputPort, output_signal.destination)
destination_op = destination.operation
destination_start_time = self._start_times[
destination_op.graph_id
]
destination_start_time = self._start_times[destination_op.graph_id]
destination_y_pos = self._get_y_position(
destination_op.graph_id, operation_gap=operation_gap
)
......@@ -804,9 +771,7 @@ class Schedule:
+ 1
+ (OPERATION_GAP if operation_gap is None else operation_gap)
)
ax.axis(
[-1, self._schedule_time + 1, y_position_max, 0]
) # Inverted y-axis
ax.axis([-1, self._schedule_time + 1, y_position_max, 0]) # Inverted y-axis
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.axvline(
0,
......@@ -823,9 +788,7 @@ class Schedule:
"""Reset all the y-locations in the schedule to None"""
self._y_locations = defaultdict(lambda: None)
def plot_in_axes(
self, ax: Axes, operation_gap: Optional[float] = None
) -> None:
def plot_in_axes(self, ax: Axes, operation_gap: Optional[float] = None) -> None:
"""
Plot the schedule in a :class:`matplotlib.axes.Axes` or subclass.
......
......@@ -54,14 +54,14 @@ documentation = "https://da.gitlab-pages.liu.se/B-ASIC/"
[tool.black]
skip-string-normalization = true
preview = true
line-length = 79
line-length = 88
exclude = [
"test/test_gui", "b_asic/scheduler_gui/ui_main_window.py"
]
[tool.isort]
profile = "black"
line_length = 79
line_length = 88
src_paths = ["b_asic", "test"]
skip = [
"test/test_gui", "b_asic/scheduler_gui/ui_main_window.py"
......
File added: test/fixtures/interleaver-two-port-issue175.p
import pickle
import matplotlib.pyplot as plt
import networkx as nx
import pytest
......@@ -6,7 +8,7 @@ from b_asic.research.interleaver import (
generate_matrix_transposer,
generate_random_interleaver,
)
from b_asic.resources import draw_exclusion_graph_coloring
from b_asic.resources import ProcessCollection, draw_exclusion_graph_coloring
class TestProcessCollectionPlainMemoryVariable:
......@@ -16,40 +18,28 @@ class TestProcessCollectionPlainMemoryVariable:
simple_collection.draw_lifetime_chart(ax=ax, show_markers=False)
return fig
def test_draw_proces_collection(self, simple_collection):
_, ax = plt.subplots(1, 2)
simple_collection.draw_lifetime_chart(ax=ax[0])
exclusion_graph = (
simple_collection.create_exclusion_graph_from_overlap()
)
color_dict = nx.coloring.greedy_color(exclusion_graph)
draw_exclusion_graph_coloring(exclusion_graph, color_dict, ax=ax[1])
def test_split_memory_variable(self, simple_collection):
collection_split = simple_collection.split(
read_ports=1, write_ports=1, total_ports=2
)
assert len(collection_split) == 3
@pytest.mark.mpl_image_compare(style='mpl20')
def test_draw_matrix_transposer_4(self):
fig, ax = plt.subplots()
generate_matrix_transposer(4).draw_lifetime_chart(ax=ax)
return fig
def test_split_memory_variable(self, simple_collection: ProcessCollection):
collection_split = simple_collection.split_ports(
heuristic="graph_color", read_ports=1, write_ports=1, total_ports=2
)
assert len(collection_split) == 3
# Issue: #175
def test_interleaver_issue175(self):
with open('test/fixtures/interleaver-two-port-issue175.p', 'rb') as f:
interleaver_collection: ProcessCollection = pickle.load(f)
assert len(interleaver_collection.split_ports(total_ports=1)) == 2
def test_generate_random_interleaver(self):
return
for _ in range(10):
for size in range(5, 20, 5):
assert (
len(
generate_random_interleaver(size).split(
read_ports=1, write_ports=1
)
)
== 1
)
assert (
len(generate_random_interleaver(size).split(total_ports=1))
== 2
)
collection = generate_random_interleaver(size)
assert len(collection.split_ports(read_ports=1, write_ports=1)) == 1
if any(var.execution_time for var in collection.collection):
assert len(collection.split_ports(total_ports=1)) == 2
......@@ -8,15 +8,13 @@ import pytest
from b_asic.core_operations import Addition, Butterfly, ConstantMultiplication
from b_asic.schedule import Schedule
from b_asic.signal_flow_graph import SFG
from b_asic.special_operations import Input, Output
from b_asic.special_operations import Delay, Input, Output
class TestInit:
def test_simple_filter_normal_latency(self, sfg_simple_filter):
sfg_simple_filter.set_latency_of_type(Addition.type_name(), 5)
sfg_simple_filter.set_latency_of_type(
ConstantMultiplication.type_name(), 4
)
sfg_simple_filter.set_latency_of_type(ConstantMultiplication.type_name(), 4)
schedule = Schedule(sfg_simple_filter)
......@@ -28,13 +26,9 @@ class TestInit:
}
assert schedule.schedule_time == 9
def test_complicated_single_outputs_normal_latency(
self, precedence_sfg_delays
):
def test_complicated_single_outputs_normal_latency(self, precedence_sfg_delays):
precedence_sfg_delays.set_latency_of_type(Addition.type_name(), 4)
precedence_sfg_delays.set_latency_of_type(
ConstantMultiplication.type_name(), 3
)
precedence_sfg_delays.set_latency_of_type(ConstantMultiplication.type_name(), 3)
schedule = Schedule(precedence_sfg_delays, scheduling_algorithm="ASAP")
......@@ -88,9 +82,7 @@ class TestInit:
}
assert secondorder_iir_schedule.schedule_time == 21
def test_complicated_single_outputs_complex_latencies(
self, precedence_sfg_delays
):
def test_complicated_single_outputs_complex_latencies(self, precedence_sfg_delays):
precedence_sfg_delays.set_latency_offsets_of_type(
ConstantMultiplication.type_name(), {"in0": 3, "out0": 5}
)
......@@ -152,9 +144,7 @@ class TestInit:
assert schedule.schedule_time == 17
def test_independent_sfg(
self, sfg_two_inputs_two_outputs_independent_with_cmul
):
def test_independent_sfg(self, sfg_two_inputs_two_outputs_independent_with_cmul):
schedule = Schedule(
sfg_two_inputs_two_outputs_independent_with_cmul,
scheduling_algorithm="ASAP",
......@@ -162,11 +152,9 @@ class TestInit:
start_times_names = {}
for op_id, start_time in schedule._start_times.items():
op_name = (
sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
op_id
).name
)
op_name = sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
op_id
).name
start_times_names[op_name] = start_time
assert start_times_names == {
......@@ -184,13 +172,9 @@ class TestInit:
class TestSlacks:
def test_forward_backward_slack_normal_latency(
self, precedence_sfg_delays
):
def test_forward_backward_slack_normal_latency(self, precedence_sfg_delays):
precedence_sfg_delays.set_latency_of_type(Addition.type_name(), 1)
precedence_sfg_delays.set_latency_of_type(
ConstantMultiplication.type_name(), 3
)
precedence_sfg_delays.set_latency_of_type(ConstantMultiplication.type_name(), 3)
schedule = Schedule(precedence_sfg_delays, scheduling_algorithm="ASAP")
assert (
......@@ -207,9 +191,7 @@ class TestSlacks:
)
assert (
schedule.forward_slack(
precedence_sfg_delays.find_by_name("A2")[0].graph_id
)
schedule.forward_slack(precedence_sfg_delays.find_by_name("A2")[0].graph_id)
== 0
)
assert (
......@@ -221,9 +203,7 @@ class TestSlacks:
def test_slacks_normal_latency(self, precedence_sfg_delays):
precedence_sfg_delays.set_latency_of_type(Addition.type_name(), 1)
precedence_sfg_delays.set_latency_of_type(
ConstantMultiplication.type_name(), 3
)
precedence_sfg_delays.set_latency_of_type(ConstantMultiplication.type_name(), 3)
schedule = Schedule(precedence_sfg_delays, scheduling_algorithm="ASAP")
assert schedule.slacks(
......@@ -237,18 +217,14 @@ class TestSlacks:
class TestRescheduling:
def test_move_operation(self, precedence_sfg_delays):
precedence_sfg_delays.set_latency_of_type(Addition.type_name(), 4)
precedence_sfg_delays.set_latency_of_type(
ConstantMultiplication.type_name(), 3
)
precedence_sfg_delays.set_latency_of_type(ConstantMultiplication.type_name(), 3)
schedule = Schedule(precedence_sfg_delays, scheduling_algorithm="ASAP")
schedule.move_operation(
precedence_sfg_delays.find_by_name("ADD3")[0].graph_id, 4
)
schedule.move_operation(
precedence_sfg_delays.find_by_name("A2")[0].graph_id, 2
)
schedule.move_operation(precedence_sfg_delays.find_by_name("A2")[0].graph_id, 2)
start_times_names = {}
for op_id, start_time in schedule._start_times.items():
......@@ -271,13 +247,9 @@ class TestRescheduling:
"OUT1": 21,
}
def test_move_operation_slack_after_rescheduling(
self, precedence_sfg_delays
):
def test_move_operation_slack_after_rescheduling(self, precedence_sfg_delays):
precedence_sfg_delays.set_latency_of_type(Addition.type_name(), 1)
precedence_sfg_delays.set_latency_of_type(
ConstantMultiplication.type_name(), 3
)
precedence_sfg_delays.set_latency_of_type(ConstantMultiplication.type_name(), 3)
schedule = Schedule(precedence_sfg_delays, scheduling_algorithm="ASAP")
add3_id = precedence_sfg_delays.find_by_name("ADD3")[0].graph_id
......@@ -297,34 +269,97 @@ class TestRescheduling:
assert schedule.forward_slack(a2_id) == 2
assert schedule.backward_slack(a2_id) == 18
def test_move_operation_incorrect_move_backward(
self, precedence_sfg_delays
):
def test_move_operation_incorrect_move_backward(self, precedence_sfg_delays):
precedence_sfg_delays.set_latency_of_type(Addition.type_name(), 1)
precedence_sfg_delays.set_latency_of_type(
ConstantMultiplication.type_name(), 3
)
precedence_sfg_delays.set_latency_of_type(ConstantMultiplication.type_name(), 3)
schedule = Schedule(precedence_sfg_delays, scheduling_algorithm="ASAP")
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Operation add4 got incorrect move: -4. Must be between 0 and 7.",
):
schedule.move_operation(
precedence_sfg_delays.find_by_name("ADD3")[0].graph_id, -4
)
def test_move_operation_incorrect_move_forward(
self, precedence_sfg_delays
):
def test_move_operation_incorrect_move_forward(self, precedence_sfg_delays):
precedence_sfg_delays.set_latency_of_type(Addition.type_name(), 1)
precedence_sfg_delays.set_latency_of_type(
ConstantMultiplication.type_name(), 3
)
precedence_sfg_delays.set_latency_of_type(ConstantMultiplication.type_name(), 3)
schedule = Schedule(precedence_sfg_delays, scheduling_algorithm="ASAP")
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="Operation add4 got incorrect move: 10. Must be between 0 and 7.",
):
schedule.move_operation(
precedence_sfg_delays.find_by_name("ADD3")[0].graph_id, 10
)
def test_move_operation_acc(self):
in0 = Input()
d = Delay()
a = d + in0
out0 = Output(a)
d << a
sfg = SFG([in0], [out0])
sfg.set_latency_of_type(Addition.type_name(), 1)
schedule = Schedule(sfg, cyclic=True)
# Check initial conditions
assert schedule.laps[sfg.find_by_id("add1").input(0).signals[0].graph_id] == 1
assert schedule.laps[sfg.find_by_id("add1").input(1).signals[0].graph_id] == 0
assert schedule._start_times["add1"] == 0
assert schedule.laps[sfg.find_by_id("out1").input(0).signals[0].graph_id] == 0
assert schedule._start_times["out1"] == 1
# Move and scheduling algorithm behaves differently
schedule.move_operation("out1", 0)
assert schedule.laps[sfg.find_by_id("out1").input(0).signals[0].graph_id] == 1
assert schedule.laps[sfg.find_by_id("add1").input(0).signals[0].graph_id] == 1
assert schedule.laps[sfg.find_by_id("add1").input(1).signals[0].graph_id] == 0
assert schedule._start_times["out1"] == 0
assert schedule._start_times["add1"] == 0
# Increase schedule time
schedule.set_schedule_time(2)
assert schedule.laps[sfg.find_by_id("out1").input(0).signals[0].graph_id] == 1
assert schedule.laps[sfg.find_by_id("add1").input(0).signals[0].graph_id] == 1
assert schedule.laps[sfg.find_by_id("add1").input(1).signals[0].graph_id] == 0
assert schedule._start_times["out1"] == 0
assert schedule._start_times["add1"] == 0
# Move out one time unit
schedule.move_operation("out1", 1)
assert schedule.laps[sfg.find_by_id("out1").input(0).signals[0].graph_id] == 1
assert schedule.laps[sfg.find_by_id("add1").input(0).signals[0].graph_id] == 1
assert schedule.laps[sfg.find_by_id("add1").input(1).signals[0].graph_id] == 0
assert schedule._start_times["out1"] == 1
assert schedule._start_times["add1"] == 0
# Move add one time unit
schedule.move_operation("add1", 1)
assert schedule.laps[sfg.find_by_id("add1").input(0).signals[0].graph_id] == 1
assert schedule.laps[sfg.find_by_id("add1").input(1).signals[0].graph_id] == 0
assert schedule.laps[sfg.find_by_id("out1").input(0).signals[0].graph_id] == 1
assert schedule._start_times["add1"] == 1
assert schedule._start_times["out1"] == 1
# Move out back one time unit
schedule.move_operation("out1", -1)
assert schedule.laps[sfg.find_by_id("out1").input(0).signals[0].graph_id] == 1
assert schedule._start_times["out1"] == 0
assert schedule.laps[sfg.find_by_id("add1").input(0).signals[0].graph_id] == 1
assert schedule.laps[sfg.find_by_id("add1").input(1).signals[0].graph_id] == 0
assert schedule._start_times["add1"] == 1
# Move add back one time unit
schedule.move_operation("add1", -1)
assert schedule.laps[sfg.find_by_id("add1").input(0).signals[0].graph_id] == 1
assert schedule.laps[sfg.find_by_id("add1").input(1).signals[0].graph_id] == 0
assert schedule.laps[sfg.find_by_id("out1").input(0).signals[0].graph_id] == 1
assert schedule._start_times["add1"] == 0
assert schedule._start_times["out1"] == 0
class TestTimeResolution:
def test_increase_time_resolution(
......@@ -341,11 +376,9 @@ class TestTimeResolution:
start_times_names = {}
for op_id, start_time in schedule._start_times.items():
op_name = (
sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
op_id
).name
)
op_name = sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
op_id
).name
start_times_names[op_name] = start_time
assert start_times_names == {
......@@ -377,11 +410,9 @@ class TestTimeResolution:
start_times_names = {}
for op_id, start_time in schedule._start_times.items():
op_name = (
sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
op_id
).name
)
op_name = sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
op_id
).name
start_times_names[op_name] = start_time
assert start_times_names == {
......@@ -418,11 +449,9 @@ class TestTimeResolution:
start_times_names = {}
for op_id, start_time in schedule._start_times.items():
op_name = (
sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
op_id
).name
)
op_name = sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
op_id
).name
start_times_names[op_name] = start_time
assert start_times_names == {
......@@ -437,19 +466,15 @@ class TestTimeResolution:
"OUT2": 60,
}
with pytest.raises(
ValueError, match="Not possible to decrease resolution"
):
with pytest.raises(ValueError, match="Not possible to decrease resolution"):
schedule.decrease_time_resolution(4)
schedule.decrease_time_resolution(3)
start_times_names = {}
for op_id, start_time in schedule._start_times.items():
op_name = (
sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
op_id
).name
)
op_name = sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
op_id
).name
start_times_names[op_name] = start_time
assert start_times_names == {
......@@ -491,9 +516,7 @@ class TestErrors:
def test_no_output_latency(self):
in1 = Input()
in2 = Input()
bfly = Butterfly(
in1, in2, latency_offsets={"in0": 4, "in1": 2, "out0": 10}
)
bfly = Butterfly(in1, in2, latency_offsets={"in0": 4, "in1": 2, "out0": 10})
out1 = Output(bfly.output(0))
out2 = Output(bfly.output(1))
sfg = SFG([in1, in2], [out1, out2])
......@@ -504,9 +527,7 @@ class TestErrors:
Schedule(sfg)
in1 = Input()
in2 = Input()
bfly1 = Butterfly(
in1, in2, latency_offsets={"in0": 4, "in1": 2, "out1": 10}
)
bfly1 = Butterfly(in1, in2, latency_offsets={"in0": 4, "in1": 2, "out1": 10})
bfly2 = Butterfly(
bfly1.output(0),
bfly1.output(1),
......@@ -523,12 +544,8 @@ class TestErrors:
def test_too_short_schedule_time(self, sfg_simple_filter):
sfg_simple_filter.set_latency_of_type(Addition.type_name(), 5)
sfg_simple_filter.set_latency_of_type(
ConstantMultiplication.type_name(), 4
)
with pytest.raises(
ValueError, match="Too short schedule time. Minimum is 9."
):
sfg_simple_filter.set_latency_of_type(ConstantMultiplication.type_name(), 4)
with pytest.raises(ValueError, match="Too short schedule time. Minimum is 9."):
Schedule(sfg_simple_filter, schedule_time=3)
schedule = Schedule(sfg_simple_filter)
......@@ -540,9 +557,7 @@ class TestErrors:
def test_incorrect_scheduling_algorithm(self, sfg_simple_filter):
sfg_simple_filter.set_latency_of_type(Addition.type_name(), 1)
sfg_simple_filter.set_latency_of_type(
ConstantMultiplication.type_name(), 2
)
sfg_simple_filter.set_latency_of_type(ConstantMultiplication.type_name(), 2)
with pytest.raises(
NotImplementedError, match="No algorithm with name: foo defined."
):
......