Newer
Older
import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, cast

from b_asic.core_operations import DontCare
from b_asic.operation import Operation
from b_asic.port import OutputPort
from b_asic.schedule import Schedule
from b_asic.signal_flow_graph import SFG
from b_asic.special_operations import Delay, Input, Output
from b_asic.types import TypeName

if TYPE_CHECKING:
    from b_asic.types import GraphID
class Scheduler(ABC):
@abstractmethod
def apply_scheduling(self, schedule: "Schedule") -> None:
    """Run this scheduler's algorithm on ``schedule``.

    Concrete schedulers must provide their own implementation; this
    base version only signals that one is missing.

    Parameters
    ----------
    schedule : Schedule
        The schedule the algorithm is applied to.

    Raises
    ------
    NotImplementedError
        Always, when the base implementation is reached.
    """
    raise NotImplementedError()
def _handle_outputs(
    self,
    schedule: "Schedule",
    non_schedulable_ops: Optional[list["GraphID"]] = None,
) -> None:
    """Place every output operation right after the value that feeds it.

    Each output's start time is set to the start time of the operation
    driving it plus the latency offset of the driving output port. Outputs
    driven by an operation listed in ``non_schedulable_ops`` are placed at
    time 0 instead.

    Parameters
    ----------
    schedule : Schedule
        Schedule whose ``start_times`` are updated in place.
    non_schedulable_ops : list of GraphID, optional
        Graph IDs of driving operations whose outputs are pinned to time 0.
        ``None`` (the default) means there are no such operations.

    Raises
    ------
    ValueError
        If the output port driving an output has no latency offset.
    """
    # ``None`` replaces the former shared mutable default ``[]``; passing
    # ``None`` explicitly (as the annotation allows) used to crash on ``in``.
    excluded = non_schedulable_ops or []
    for output in schedule.sfg.find_by_type_name(Output.type_name()):
        output = cast(Output, output)
        # Only the first signal of the first input is followed back.
        source_port = cast(OutputPort, output.inputs[0].signals[0].source)
        source_op_id = source_port.operation.graph_id
        if source_op_id in excluded:
            schedule.start_times[output.graph_id] = 0
            continue
        if source_port.latency_offset is None:
            raise ValueError(
                f"Output port {source_port.index} of operation"
                f" {source_port.operation.graph_id} has no"
                " latency-offset."
            )
        schedule.start_times[output.graph_id] = schedule.start_times[
            source_op_id
        ] + cast(int, source_port.latency_offset)
def __init__(
    self,
    max_resources: Optional[dict[TypeName, int]] = None,
    max_concurrent_reads: Optional[int] = None,
    max_concurrent_writes: Optional[int] = None,
    input_times: Optional[dict["GraphID", int]] = None,
    output_delta_times: Optional[dict["GraphID", int]] = None,
    cyclic: Optional[bool] = False,
) -> None:
    """Store and validate the scheduling constraints.

    Parameters
    ----------
    max_resources : dict of TypeName to int, optional
        Maximum number of concurrently used resources per operation type
        name. The dictionary is copied.
    max_concurrent_reads : int, optional
        Maximum number of memory reads per time step. ``None`` (default)
        means no limit (``sys.maxsize``).
    max_concurrent_writes : int, optional
        Maximum number of memory writes per time step. ``None`` (default)
        means no limit (``sys.maxsize``).
    input_times : dict of GraphID to int, optional
        Fixed start times for the listed input operations.
    output_delta_times : dict of GraphID to int, optional
        Offsets, relative to the schedule time, for the listed outputs.
    cyclic : bool, default False
        Accepted but not stored by this initializer.

    Raises
    ------
    ValueError
        If ``max_resources`` is not a dict, has a non-``str`` key, or has a
        value that is not a non-negative integer.
    """
    # The bare ``super()`` expression was a no-op; actually chain the call.
    super().__init__()
    if max_resources is not None:
        if not isinstance(max_resources, dict):
            raise ValueError("max_resources must be a dictionary.")
        for key, value in max_resources.items():
            if not isinstance(key, str):
                raise ValueError("max_resources key must be a valid type_name.")
            if not isinstance(value, int):
                raise ValueError("max_resources value must be an integer.")
            # A negative limit would never hit the "== 0 left" check and so
            # would silently behave as an unlimited resource.
            if value < 0:
                raise ValueError("max_resources value must be non-negative.")
        # Copy so later changes to the caller's dict do not leak in.
        self._max_resources = dict(max_resources)
    else:
        self._max_resources = {}
    # Only ``None`` means "unlimited"; an explicit 0 is no longer silently
    # turned into ``sys.maxsize`` by a truthiness test.
    self._max_concurrent_reads = (
        sys.maxsize if max_concurrent_reads is None else max_concurrent_reads
    )
    self._max_concurrent_writes = (
        sys.maxsize if max_concurrent_writes is None else max_concurrent_writes
    )
    self._input_times = input_times or {}
    self._output_delta_times = output_delta_times or {}
    # NOTE(review): ``cyclic`` is not stored or used here; the scheduling
    # code reads ``schedule.cyclic`` instead — confirm it is still needed.
def apply_scheduling(self, schedule: "Schedule") -> None:
    """Apply the list-scheduling algorithm on the given Schedule.

    Operations are taken in the priority order returned by
    ``self._get_sorted_operations(schedule)``. At each time step the
    highest-priority remaining operation accepted by
    ``self._candidate_is_schedulable`` is placed at the current time. When
    no operation is accepted, time advances by one step, the read budget is
    refilled and resource units whose ready time has been reached are
    released. The search always restarts from the highest-priority
    operation.

    Parameters
    ----------
    schedule : Schedule
        Schedule to apply the scheduling algorithm on. Its ``start_times``
        are replaced. For non-cyclic schedules the schedule time is set to
        cover every placed operation.

    Raises
    ------
    TimeoutError
        If time advances more than 10 consecutive steps without any
        operation being scheduled.
    """
    sfg = schedule.sfg
    current_time = 0
    # Operation -> time step at which the resource unit it occupies is freed.
    used_resources_ready_times = {}
    remaining_resources = self._max_resources.copy()
    # Unless configured otherwise, allow one input and one output at a time.
    remaining_resources.setdefault(Input.type_name(), 1)
    remaining_resources.setdefault(Output.type_name(), 1)

    # Copy so removing scheduled IDs never mutates a list owned elsewhere.
    sorted_operations = list(self._get_sorted_operations(schedule))

    schedule.start_times = {}
    remaining_reads = self._max_concurrent_reads

    # Initial input placement. Remove exactly the inputs that received a
    # fixed time: the old "in"-prefix filter also dropped inputs without a
    # given time (leaving them unscheduled) and any other op whose graph ID
    # happens to start with "in".
    if self._input_times:
        for input_id, input_time in self._input_times.items():
            schedule.start_times[input_id] = input_time
        sorted_operations = [
            op_id for op_id in sorted_operations if op_id not in self._input_times
        ]

    timeout_counter = 0
    while sorted_operations:
        # Find the highest-priority operation that fits at current_time,
        # advancing time until one does.
        candidate = None
        while candidate is None:
            for op_id in sorted_operations:
                operation = sfg.find_by_id(op_id)
                if self._candidate_is_schedulable(
                    schedule.start_times,
                    sfg,
                    operation,
                    current_time,
                    remaining_resources,
                    remaining_reads,
                    self._max_concurrent_writes,
                    sorted_operations,
                ):
                    candidate = operation
                    break
            else:
                # Nothing fits at this time step: advance time, refill the
                # read budget and release resources that are free again.
                current_time += 1
                timeout_counter += 1
                if timeout_counter > 10:
                    msg = (
                        "Algorithm did not schedule any operation for 10 time"
                        " steps, try relaxing constraints."
                    )
                    raise TimeoutError(msg)
                remaining_reads = self._max_concurrent_reads
                # ``<=`` plus removal (instead of ``==``) also releases units
                # whose ready time equals the time they were scheduled at.
                freed = [
                    busy_op
                    for busy_op, ready_time in used_resources_ready_times.items()
                    if ready_time <= current_time
                ]
                for busy_op in freed:
                    remaining_resources[busy_op.type_name()] += 1
                    del used_resources_ready_times[busy_op]
        timeout_counter = 0

        # If the resource is constrained, reserve a unit and record when it
        # becomes available again.
        type_name = candidate.type_name()
        if type_name in remaining_resources:
            remaining_resources[type_name] -= 1
            if type_name in (Input.type_name(), Output.type_name()):
                used_resources_ready_times[candidate] = current_time + 1
            else:
                used_resources_ready_times[candidate] = (
                    current_time + candidate.execution_time
                )
        remaining_reads -= candidate.input_count

        # Schedule the chosen candidate at the current time.
        sorted_operations.remove(candidate.graph_id)
        schedule.start_times[candidate.graph_id] = current_time

    if not schedule.cyclic:
        # Stretch the schedule so every placed operation fits inside it;
        # ``default`` guards against an SFG with nothing to schedule.
        current_time = max(
            current_time,
            max(schedule.start_times.values(), default=current_time),
            schedule.get_max_end_time(),
        )
        schedule.set_schedule_time(current_time)

    # Move all don't-care operations as late as possible.
    for dc_op in schedule.sfg.find_by_type_name(DontCare.type_name()):
        dc_op = cast(DontCare, dc_op)
        schedule.move_operation_alap(dc_op.graph_id)
def _candidate_is_schedulable(
start_times: dict["GraphID"],
sfg: "SFG",
operation: "Operation",
current_time: int,
remaining_resources: dict["GraphID", int],
remaining_reads: int,
max_concurrent_writes: int,
remaining_ops: list["GraphID"],
) -> bool:
if (
operation.type_name() in remaining_resources
and remaining_resources[operation.type_name()] == 0
):
return False
op_finish_time = current_time + operation.latency
future_ops = [
sfg.find_by_id(item[0])
for item in start_times.items()
if item[1] + sfg.find_by_id(item[0]).latency == op_finish_time
]
future_ops_writes = sum([op.input_count for op in future_ops])
if (
not operation.graph_id.startswith("out")
and future_ops_writes >= max_concurrent_writes
):
return False
read_counter = 0
earliest_start_time = 0
for op_input in operation.inputs:
source_op = op_input.signals[0].source.operation
if isinstance(source_op, Delay):
continue
source_op_graph_id = source_op.graph_id
if source_op_graph_id in remaining_ops:
return False
if start_times[source_op_graph_id] != current_time - 1:
# not a direct connection -> memory read required
read_counter += 1
if read_counter > remaining_reads:
return False
proceeding_op_start_time = start_times.get(source_op_graph_id)
proceeding_op_finish_time = proceeding_op_start_time + source_op.latency
earliest_start_time = max(earliest_start_time, proceeding_op_finish_time)
return earliest_start_time <= current_time
@abstractmethod
def _get_sorted_operations(schedule: "Schedule") -> list["GraphID"]:
raise NotImplementedError
def _handle_outputs(
    self,
    schedule: "Schedule",
    non_schedulable_ops: Optional[list["GraphID"]] = None,
) -> None:
    """Fix the schedule time and place outputs that have a delta time.

    The schedule time is first set to the schedule's maximum end time.
    Every output listed in ``self._output_delta_times`` is then placed
    relative to that schedule time:

    * cyclic schedules: start at the schedule time, then
      ``schedule.move_operation`` by the delta;
    * otherwise: start time is set directly to ``schedule_time + delta``.

    Outputs without a delta time are left untouched.

    Parameters
    ----------
    schedule : Schedule
        Schedule to update in place.
    non_schedulable_ops : list of GraphID, optional
        Not used by this implementation. The default is ``None`` instead
        of a shared mutable ``[]``.
    """
    schedule.set_schedule_time(schedule.get_max_end_time())
    for output in schedule.sfg.find_by_type_name(Output.type_name()):
        output = cast(Output, output)
        if output.graph_id not in self._output_delta_times:
            continue
        delta_time = self._output_delta_times[output.graph_id]
        if schedule.cyclic:
            schedule.start_times[output.graph_id] = schedule.schedule_time
            schedule.move_operation(output.graph_id, delta_time)
        else:
            schedule.start_times[output.graph_id] = (
                schedule.schedule_time + delta_time
            )