import copy
import sys
from abc import ABC, abstractmethod
from math import ceil
from typing import TYPE_CHECKING, Optional, cast

import b_asic.logger as logger
from b_asic.core_operations import DontCare
from b_asic.port import OutputPort
from b_asic.special_operations import Delay, Input, Output
from b_asic.types import TypeName

if TYPE_CHECKING:
    from b_asic.operation import Operation
    from b_asic.schedule import Schedule
    from b_asic.types import GraphID


class Scheduler(ABC):
    """Abstract base class for all scheduling algorithms."""

    @abstractmethod
    def apply_scheduling(self, schedule: "Schedule") -> None:
        """Applies the scheduling algorithm on the given Schedule.

        Parameters
        ----------
        schedule : Schedule
            Schedule to apply the scheduling algorithm on.
        """
        raise NotImplementedError

    def _handle_outputs(
        self,
        schedule: "Schedule",
        non_schedulable_ops: Optional[list["GraphID"]] = None,
    ) -> None:
        """Place all Output operations directly after their source operation.

        Outputs fed by a non-schedulable operation (e.g. a Delay) are placed
        at time 0.

        Parameters
        ----------
        schedule : Schedule
            Schedule whose ``start_times`` are updated in place.
        non_schedulable_ops : list[GraphID], optional
            Graph IDs of operations that cannot be scheduled themselves.

        Raises
        ------
        ValueError
            If the source output port has no latency-offset.
        """
        # Fix: the default used to be a shared mutable list ([]); use the
        # None-sentinel idiom instead.
        if non_schedulable_ops is None:
            non_schedulable_ops = []
        for output in schedule.sfg.find_by_type_name(Output.type_name()):
            output = cast(Output, output)
            source_port = cast(OutputPort, output.inputs[0].signals[0].source)
            if source_port.operation.graph_id in non_schedulable_ops:
                schedule.start_times[output.graph_id] = 0
            else:
                if source_port.latency_offset is None:
                    raise ValueError(
                        f"Output port {source_port.index} of operation"
                        f" {source_port.operation.graph_id} has no"
                        " latency-offset."
                    )
                # Output fires as soon as its driving port's result is ready.
                schedule.start_times[output.graph_id] = schedule.start_times[
                    source_port.operation.graph_id
                ] + cast(int, source_port.latency_offset)


class ASAPScheduler(Scheduler):
    """Scheduler that implements the as-soon-as-possible (ASAP) algorithm."""

    def apply_scheduling(self, schedule: "Schedule") -> None:
        """Applies the scheduling algorithm on the given Schedule.

        Parameters
        ----------
        schedule : Schedule
            Schedule to apply the scheduling algorithm on.

        Raises
        ------
        ValueError
            If the SFG is empty, a port lacks a latency-offset, or the
            requested schedule time is shorter than the ASAP makespan.
        """
        prec_list = schedule.sfg.get_precedence_list()
        if len(prec_list) < 2:
            raise ValueError("Empty signal flow graph cannot be scheduled.")

        # handle the first set in precedence graph (input and delays)
        non_schedulable_ops = []
        for outport in prec_list[0]:
            operation = outport.operation
            if operation.type_name() == Delay.type_name():
                # Delays are structural and are removed later; they get no
                # start time of their own.
                non_schedulable_ops.append(operation.graph_id)
            else:
                schedule.start_times[operation.graph_id] = 0

        # handle second set in precedence graph (first operations)
        for outport in prec_list[1]:
            operation = outport.operation
            schedule.start_times[operation.graph_id] = 0

        # handle the remaining sets
        for outports in prec_list[2:]:
            for outport in outports:
                operation = outport.operation
                if operation.graph_id not in schedule.start_times:
                    # Earliest start = max over inputs of
                    # (source finish time - input latency-offset).
                    op_start_time = 0
                    for current_input in operation.inputs:
                        source_port = current_input.signals[0].source
                        if source_port.operation.graph_id in non_schedulable_ops:
                            source_end_time = 0
                        else:
                            source_op_time = schedule.start_times[
                                source_port.operation.graph_id
                            ]
                            if source_port.latency_offset is None:
                                raise ValueError(
                                    f"Output port {source_port.index} of"
                                    " operation"
                                    f" {source_port.operation.graph_id} has no"
                                    " latency-offset."
                                )
                            source_end_time = (
                                source_op_time + source_port.latency_offset
                            )
                        if current_input.latency_offset is None:
                            raise ValueError(
                                f"Input port {current_input.index} of operation"
                                f" {current_input.operation.graph_id} has no"
                                " latency-offset."
                            )
                        op_start_time_from_in = (
                            source_end_time - current_input.latency_offset
                        )
                        op_start_time = max(op_start_time, op_start_time_from_in)
                    schedule.start_times[operation.graph_id] = op_start_time

        self._handle_outputs(schedule, non_schedulable_ops)
        schedule.remove_delays()

        max_end_time = schedule.get_max_end_time()
        if schedule.schedule_time is None:
            schedule.set_schedule_time(max_end_time)
        elif schedule.schedule_time < max_end_time:
            raise ValueError(
                f"Too short schedule time. Minimum is {max_end_time}."
            )

        schedule.sort_y_locations_on_start_times()


class ALAPScheduler(Scheduler):
    """Scheduler that implements the as-late-as-possible (ALAP) algorithm."""

    def apply_scheduling(self, schedule: "Schedule") -> None:
        """Applies the scheduling algorithm on the given Schedule.

        Parameters
        ----------
        schedule : Schedule
            Schedule to apply the scheduling algorithm on.
        """
        # ALAP = ASAP first, then push everything as late as dependencies
        # allow, starting with the outputs.
        ASAPScheduler().apply_scheduling(schedule)

        # move all outputs ALAP before operations
        for output in schedule.sfg.find_by_type_name(Output.type_name()):
            output = cast(Output, output)
            schedule.move_operation_alap(output.graph_id)

        # move all operations ALAP, in reverse precedence order
        for step in reversed(schedule.sfg.get_precedence_list()):
            for outport in step:
                if not isinstance(outport.operation, Delay):
                    schedule.move_operation_alap(outport.operation.graph_id)

        schedule.sort_y_locations_on_start_times()


class ListScheduler(Scheduler, ABC):
    """Base class for resource-constrained list schedulers.

    Subclasses define the priority ordering via :attr:`sort_indices`.
    """

    # Abort if no operation could be scheduled for this many time steps.
    TIME_OUT_COUNTER_LIMIT = 100

    def __init__(
        self,
        max_resources: Optional[dict[TypeName, int]] = None,
        max_concurrent_reads: Optional[int] = None,
        max_concurrent_writes: Optional[int] = None,
        input_times: Optional[dict["GraphID", int]] = None,
        output_delta_times: Optional[dict["GraphID", int]] = None,
        cyclic: Optional[bool] = False,
    ) -> None:
        """Initialize the list scheduler.

        Parameters
        ----------
        max_resources : dict[TypeName, int], optional
            Maximum number of concurrent operations per type name.
            Input/Output default to 1 if not given.
        max_concurrent_reads : int, optional
            Maximum number of memory reads per time step (unbounded if None).
        max_concurrent_writes : int, optional
            Maximum number of memory writes per time step (unbounded if None).
        input_times : dict[GraphID, int], optional
            Fixed start times for Input operations.
        output_delta_times : dict[GraphID, int], optional
            Offsets (relative to the schedule end) for Output operations.
        cyclic : bool, optional
            NOTE(review): accepted for interface compatibility but not stored;
            the schedule's own ``cyclic`` flag is consulted instead — confirm
            this is intentional.

        Raises
        ------
        ValueError
            If ``max_resources`` is not a dict of str -> int.
        """
        # Fix: bare ``super()`` was a no-op; actually call the base initializer.
        super().__init__()
        self._logger = logger.getLogger(__name__, "list_scheduler.log", "DEBUG")
        if max_resources is not None:
            if not isinstance(max_resources, dict):
                raise ValueError("max_resources must be a dictionary.")
            for key, value in max_resources.items():
                if not isinstance(key, str):
                    raise ValueError("max_resources key must be a valid type_name.")
                if not isinstance(value, int):
                    raise ValueError("max_resources value must be an integer.")
            self._max_resources = max_resources
        else:
            self._max_resources = {}

        # I/O resources default to one unit each. Note this mutates a
        # caller-provided dict (pre-existing behavior).
        if Input.type_name() not in self._max_resources:
            self._max_resources[Input.type_name()] = 1
        if Output.type_name() not in self._max_resources:
            self._max_resources[Output.type_name()] = 1

        self._max_concurrent_reads = max_concurrent_reads or sys.maxsize
        self._max_concurrent_writes = max_concurrent_writes or sys.maxsize

        self._input_times = input_times or {}
        self._output_delta_times = output_delta_times or {}

    @property
    @abstractmethod
    def sort_indices(self) -> tuple[tuple[int, bool], ...]:
        """Priority-table sort spec: (column index, ascending) pairs."""
        raise NotImplementedError

    def apply_scheduling(self, schedule: "Schedule") -> None:
        """Applies the scheduling algorithm on the given Schedule.

        Parameters
        ----------
        schedule : Schedule
            Schedule to apply the scheduling algorithm on.

        Raises
        ------
        ValueError
            If constraints make the requested schedule time unreachable.
        TimeoutError
            If no operation can be scheduled for
            ``TIME_OUT_COUNTER_LIMIT`` consecutive time steps.
        """
        self._logger.debug("--- Scheduler initializing ---")

        self._schedule = schedule
        self._sfg = schedule.sfg

        if self._schedule.cyclic and self._schedule.schedule_time is None:
            raise ValueError("Scheduling time must be provided when cyclic = True.")

        # Sanity check: each resource type must at least meet the
        # utilization lower bound ceil(total_exec_time / schedule_time).
        for resource_type, resource_amount in self._max_resources.items():
            total_exec_time = sum(
                op.execution_time
                for op in self._sfg.find_by_type_name(resource_type)
            )
            if self._schedule.schedule_time is not None:
                resource_lower_bound = ceil(
                    total_exec_time / self._schedule.schedule_time
                )
                if resource_amount < resource_lower_bound:
                    raise ValueError(
                        f"Amount of resource: {resource_type} is not enough to "
                        f"realize schedule for scheduling time: {self._schedule.schedule_time}."
                    )

        # Unconstrained ALAP schedule provides deadlines/slacks for priorities.
        alap_schedule = copy.copy(self._schedule)
        alap_schedule._schedule_time = None
        ALAPScheduler().apply_scheduling(alap_schedule)
        alap_start_times = alap_schedule.start_times
        self._schedule.start_times = {}

        if not self._schedule.cyclic and self._schedule.schedule_time:
            if alap_schedule.schedule_time > self._schedule.schedule_time:
                raise ValueError(
                    f"Provided scheduling time {schedule.schedule_time} cannot be reached, "
                    "try to enable the cyclic property or increase the time to at least "
                    f"{alap_schedule.schedule_time}."
                )

        self._remaining_resources = self._max_resources.copy()

        self._remaining_ops = [op.graph_id for op in self._sfg.operations]

        self._cached_latencies = {
            op_id: self._sfg.find_by_id(op_id).latency
            for op_id in self._remaining_ops
        }
        self._cached_execution_times = {
            op_id: self._sfg.find_by_id(op_id).execution_time
            for op_id in self._remaining_ops
        }

        self._deadlines = self._calculate_deadlines(alap_start_times)
        self._output_slacks = self._calculate_alap_output_slacks(alap_start_times)
        self._fan_outs = self._calculate_fan_outs(alap_start_times)

        self._schedule.start_times = {}
        self.remaining_reads = self._max_concurrent_reads

        self._current_time = 0
        self._time_out_counter = 0
        self._op_laps = {}

        # Exclude operations handled elsewhere: DontCares ("dontcare..."),
        # Delays ("t..."), and Outputs with explicit delta times.
        self._remaining_ops = [
            op for op in self._remaining_ops if not op.startswith("dontcare")
        ]
        self._remaining_ops = [
            op for op in self._remaining_ops if not op.startswith("t")
        ]
        self._remaining_ops = [
            op
            for op in self._remaining_ops
            if not (op.startswith("out") and op in self._output_delta_times)
        ]

        if self._input_times:
            self._logger.debug("--- Input placement starting ---")
            for input_id in self._input_times:
                self._schedule.start_times[input_id] = self._input_times[input_id]
                self._op_laps[input_id] = 0
                self._logger.debug(
                    f" {input_id} time: {self._schedule.start_times[input_id]}"
                )
            self._remaining_ops = [
                elem for elem in self._remaining_ops if not elem.startswith("in")
            ]
            self._logger.debug("--- Input placement completed ---")

        self._logger.debug("--- Operation scheduling starting ---")
        while self._remaining_ops:
            # Greedily place every ready operation at the current time step,
            # re-evaluating readiness after each placement.
            ready_ops_priority_table = self._get_ready_ops_priority_table()
            while ready_ops_priority_table:
                next_op = self._sfg.find_by_id(
                    self._get_next_op_id(ready_ops_priority_table)
                )

                self.remaining_reads -= next_op.input_count

                self._remaining_ops = [
                    op_id
                    for op_id in self._remaining_ops
                    if op_id != next_op.graph_id
                ]

                self._time_out_counter = 0
                self._schedule.place_operation(next_op, self._current_time)
                # Track which lap of the (cyclic) schedule the op starts in.
                self._op_laps[next_op.graph_id] = (
                    (self._current_time) // self._schedule.schedule_time
                    if self._schedule.schedule_time
                    else 0
                )

                if self._schedule.schedule_time is not None:
                    self._logger.debug(
                        f" Op: {next_op.graph_id}, time: {self._current_time % self._schedule.schedule_time}"
                    )
                else:
                    self._logger.debug(
                        f" Op: {next_op.graph_id}, time: {self._current_time}"
                    )

                ready_ops_priority_table = self._get_ready_ops_priority_table()

            self._go_to_next_time_step()
            # Read budget resets every time step.
            self.remaining_reads = self._max_concurrent_reads

        self._logger.debug("--- Operation scheduling completed ---")

        self._current_time -= 1

        if self._output_delta_times:
            self._handle_outputs()

        if self._schedule.schedule_time is None:
            self._schedule.set_schedule_time(self._schedule.get_max_end_time())

        self._schedule.remove_delays()

        # schedule all dont cares ALAP
        for dc_op in self._sfg.find_by_type_name(DontCare.type_name()):
            dc_op = cast(DontCare, dc_op)
            self._schedule.start_times[dc_op.graph_id] = 0
            self._schedule.move_operation_alap(dc_op.graph_id)

        self._schedule.sort_y_locations_on_start_times()
        self._logger.debug("--- Scheduling completed ---")

    def _go_to_next_time_step(self):
        """Advance the clock one step, aborting after too many idle steps."""
        self._time_out_counter += 1
        if self._time_out_counter >= self.TIME_OUT_COUNTER_LIMIT:
            # Fix: the message used to hard-code "10" while the limit is
            # TIME_OUT_COUNTER_LIMIT (100); derive it from the constant.
            raise TimeoutError(
                "Algorithm did not manage to schedule any operation for "
                f"{self.TIME_OUT_COUNTER_LIMIT} time steps, "
                "try relaxing the constraints."
            )
        self._current_time += 1

    def _get_next_op_id(
        self, ready_ops_priority_table: list[tuple["GraphID", int, ...]]
    ) -> "GraphID":
        """Return the graph ID of the highest-priority ready operation."""

        def sort_key(item):
            # Negate descending columns so a single ascending sort applies.
            # (Equivalent ordering to the previous tuple-of-1-tuples key.)
            return tuple(
                item[index] if asc else -item[index]
                for index, asc in self.sort_indices
            )

        sorted_table = sorted(ready_ops_priority_table, key=sort_key)
        return sorted_table[0][0]

    def _get_ready_ops_priority_table(self) -> list[tuple["GraphID", int, int, int]]:
        """Build (op_id, deadline, output slack, fan-out) rows for ready ops."""
        ready_ops = [
            op_id
            for op_id in self._remaining_ops
            if self._op_is_schedulable(self._sfg.find_by_id(op_id))
        ]

        return [
            (
                op_id,
                self._deadlines[op_id],
                self._output_slacks[op_id],
                self._fan_outs[op_id],
            )
            for op_id in ready_ops
        ]

    def _calculate_deadlines(
        self, alap_start_times: dict["GraphID", int]
    ) -> dict["GraphID", int]:
        """ALAP finish time (start + latency) per operation."""
        return {
            op_id: start_time + self._cached_latencies[op_id]
            for op_id, start_time in alap_start_times.items()
        }

    def _calculate_alap_output_slacks(
        self, alap_start_times: dict["GraphID", int]
    ) -> dict["GraphID", int]:
        """ALAP start time per operation (used as output slack metric)."""
        return dict(alap_start_times)

    def _calculate_fan_outs(
        self, alap_start_times: dict["GraphID", int]
    ) -> dict["GraphID", int]:
        """Number of outgoing signals per operation."""
        return {
            op_id: len(self._sfg.find_by_id(op_id).output_signals)
            for op_id in alap_start_times
        }

    def _op_satisfies_resource_constraints(self, op: "Operation") -> bool:
        """Check that executing ``op`` now stays within its resource budget."""
        if self._schedule.schedule_time is not None:
            time_slot = self._current_time % self._schedule.schedule_time
        else:
            time_slot = self._current_time

        # Count same-type operations already executing in this time slot.
        count = 0
        for op_id, start_time in self._schedule.start_times.items():
            if self._schedule.schedule_time is not None:
                start_time = start_time % self._schedule.schedule_time
            if time_slot >= start_time:
                # Execution occupies max(execution_time, 1) slots.
                if time_slot < start_time + max(self._cached_execution_times[op_id], 1):
                    if op_id.startswith(op.type_name()):
                        if op.graph_id != op_id:
                            count += 1

        return count < self._remaining_resources[op.type_name()]

    def _op_is_schedulable(self, op: "Operation") -> bool:
        """Check resource, read/write, and precedence constraints for ``op``."""
        if not self._op_satisfies_resource_constraints(op):
            return False

        op_finish_time = self._current_time + self._cached_latencies[op.graph_id]
        # Operations finishing at the same instant compete for write ports.
        future_ops = [
            self._sfg.find_by_id(item[0])
            for item in self._schedule.start_times.items()
            if item[1] + self._cached_latencies[item[0]] == op_finish_time
        ]

        future_ops_writes = sum(f_op.input_count for f_op in future_ops)

        if (
            not op.graph_id.startswith("out")
            and future_ops_writes >= self._max_concurrent_writes
        ):
            return False

        read_counter = 0
        earliest_start_time = 0
        for op_input in op.inputs:
            source_op = op_input.signals[0].source.operation
            if isinstance(source_op, (Delay, DontCare)):
                continue

            source_op_graph_id = source_op.graph_id

            if source_op_graph_id in self._remaining_ops:
                # Predecessor not scheduled yet.
                return False

            if self._schedule.start_times[source_op_graph_id] != self._current_time - 1:
                # not a direct connection -> memory read required
                read_counter += 1
            if read_counter > self.remaining_reads:
                return False

            # Finish time of the predecessor, unwrapped by its lap count in
            # cyclic schedules.
            if self._schedule.schedule_time is not None:
                proceeding_op_start_time = (
                    self._schedule.start_times.get(source_op_graph_id)
                    + self._op_laps[source_op.graph_id]
                    * self._schedule.schedule_time
                )
            else:
                proceeding_op_start_time = self._schedule.start_times.get(
                    source_op_graph_id
                )
            proceeding_op_finish_time = (
                proceeding_op_start_time + self._cached_latencies[source_op.graph_id]
            )
            earliest_start_time = max(earliest_start_time, proceeding_op_finish_time)

        return earliest_start_time <= self._current_time

    def _handle_outputs(self) -> None:
        """Place Outputs with explicit ``output_delta_times`` after the end.

        Raises
        ------
        ValueError
            If an output cannot be placed at its requested time.
        """
        self._logger.debug("--- Output placement starting ---")
        if self._schedule.cyclic:
            end = self._schedule.schedule_time
        else:
            end = self._schedule.get_max_end_time()
        for output in self._sfg.find_by_type_name(Output.type_name()):
            output = cast(Output, output)
            if output.graph_id in self._output_delta_times:
                delta_time = self._output_delta_times[output.graph_id]

                new_time = end + delta_time

                if self._schedule.cyclic and self._schedule.schedule_time is not None:
                    self._schedule.place_operation(output, new_time)
                else:
                    self._schedule.start_times[output.graph_id] = new_time

                count = -1
                for op_id, time in self._schedule.start_times.items():
                    if time == new_time and op_id.startswith("out"):
                        count += 1

                # Fix: previously aliased self._max_resources and then mutated
                # it in place, corrupting the limits across iterations.
                self._remaining_resources = self._max_resources.copy()
                self._remaining_resources[Output.type_name()] -= count

                self._current_time = new_time
                if not self._op_is_schedulable(output):
                    raise ValueError(
                        "Cannot schedule outputs according to the provided output_delta_times. "
                        f"Failed output: {output.graph_id}, "
                        f"at time: { self._schedule.start_times[output.graph_id]}, "
                        "try relaxing the constraints."
                    )

                modulo_time = (
                    new_time % self._schedule.schedule_time
                    if self._schedule.schedule_time
                    else new_time
                )
                self._logger.debug(f" {output.graph_id} time: {modulo_time}")

        self._logger.debug("--- Output placement completed ---")