def __init__(
    self,
    input_times: dict["GraphID", int] | None = None,
    output_delta_times: dict["GraphID", int] | None = None,
):
    """
    Initialize the state shared by all schedulers.

    Parameters
    ----------
    input_times : dict(GraphID, int), optional
        Non-negative integer times keyed by graph ID strings. The dictionary
        is stored as given (not copied). An empty dictionary is used when
        omitted.
    output_delta_times : dict(GraphID, int), optional
        Non-negative integer time offsets keyed by graph ID strings. The
        dictionary is stored as given (not copied). An empty dictionary is
        used when omitted.

    Raises
    ------
    ValueError
        If a provided argument is not a dictionary, has a non-string key,
        a non-integer value or a negative value.
    """
    self._logger = logger.getLogger("scheduler")
    self._op_laps = {}

    # Both mappings obey the same rules, so they share one validator.
    # input_times is validated (and stored) first, as before.
    self._input_times = (
        {}
        if input_times is None
        else self._validate_io_times(input_times, "input_times")
    )
    self._output_delta_times = (
        {}
        if output_delta_times is None
        else self._validate_io_times(output_delta_times, "output_delta_times")
    )

@staticmethod
def _validate_io_times(
    times: dict["GraphID", int], name: str
) -> dict["GraphID", int]:
    """
    Check a user-supplied time mapping and return it unchanged.

    Parameters
    ----------
    times : dict(GraphID, int)
        The mapping to check.
    name : str
        Argument name inserted into the error messages.

    Raises
    ------
    ValueError
        If *times* is not a dictionary, or contains a non-string key,
        a non-integer value or a negative value.
    """
    if not isinstance(times, dict):
        raise ValueError(f"Provided {name} must be a dictionary.")
    # Type checks run over every item before any sign check, so a
    # non-integer value is reported as such rather than failing on "<".
    for key, value in times.items():
        if not isinstance(key, str):
            raise ValueError(f"Provided {name} keys must be strings.")
        if not isinstance(value, int):
            raise ValueError(f"Provided {name} values must be integers.")
    if any(time < 0 for time in times.values()):
        raise ValueError(f"Provided {name} values must be non-negative.")
    return times
def _place_inputs_on_given_times(self) -> None:
    """
    Place every input listed in ``self._input_times`` at its given time.

    Writes the start time into ``self._schedule.start_times`` and sets the
    lap count of the input in ``self._op_laps`` to zero.
    """
    self._logger.debug("--- Input placement starting ---")
    for input_id, input_time in self._input_times.items():
        self._schedule.start_times[input_id] = input_time
        self._op_laps[input_id] = 0
        self._logger.debug(
            f" {input_id} time: {self._schedule.start_times[input_id]}"
        )
    self._logger.debug("--- Input placement completed ---")

def _place_outputs_on_given_times(self) -> None:
    """
    Place outputs relative to the schedule end, then remove common slack.

    Every output listed in ``self._output_delta_times`` is placed at
    ``end + delta``. Afterwards all outputs are moved backwards by the
    smallest backward slack found among the outputs.

    Raises
    ------
    ValueError
        If the schedule is non-cyclic, has a schedule time, and an output
        ends up after that schedule time.
    """
    schedule = self._schedule

    self._logger.debug("--- Output placement starting ---")
    # Wrapped (modulo) placement is only possible for a cyclic list
    # schedule that already has a schedule time. Using the same condition
    # for ``end`` avoids computing ``None + delta`` when no schedule time
    # has been set yet.
    wrap_outputs = (
        schedule._cyclic
        and schedule._schedule_time is not None
        and isinstance(self, ListScheduler)
    )
    end = schedule._schedule_time if wrap_outputs else schedule.get_max_end_time()

    for output in self._sfg.find_by_type(Output):
        output = cast(Output, output)
        if output.graph_id not in self._output_delta_times:
            continue

        new_time = end + self._output_delta_times[output.graph_id]
        if wrap_outputs:
            schedule.place_operation(output, new_time, self._op_laps)
        else:
            schedule.start_times[output.graph_id] = new_time

        modulo_time = (
            new_time % schedule._schedule_time
            if schedule._schedule_time
            else new_time
        )
        self._logger.debug(f" {output.graph_id} time: {modulo_time}")
    self._logger.debug("--- Output placement completed ---")

    self._logger.debug("--- Output placement optimization starting ---")
    # default=0 keeps this a no-op when there are no outputs at all.
    min_slack = min(
        (
            schedule.backward_slack(op.graph_id)
            for op in self._sfg.find_by_type(Output)
        ),
        default=0,
    )
    if min_slack != 0:
        for output in self._sfg.find_by_type(Output):
            if schedule._cyclic and schedule._schedule_time is not None:
                schedule.move_operation(output.graph_id, -min_slack)
            else:
                schedule.start_times[output.graph_id] = (
                    schedule.start_times[output.graph_id] - min_slack
                )
            new_time = schedule.start_times[output.graph_id]
            if (
                not schedule._cyclic
                and schedule._schedule_time is not None
                and new_time > schedule._schedule_time
            ):
                raise ValueError(
                    f"Cannot place output {output.graph_id} at time {new_time} "
                    f"for scheduling time {schedule._schedule_time}. "
                    "Try to relax the scheduling time, change the output delta times or enable cyclic."
                )
            self._logger.debug(
                f" {output.graph_id} moved {min_slack} time steps backwards to new time {new_time}"
            )
    self._logger.debug("--- Output placement optimization completed ---")
""" - + self._schedule = schedule + self._sfg = schedule._sfg prec_list = schedule.sfg.get_precedence_list() if len(prec_list) < 2: raise ValueError("Empty signal flow graph cannot be scheduled.") + if self._input_times: + self._place_inputs_on_given_times() + # handle the first set in precedence graph (input and delays) non_schedulable_ops = [] for outport in prec_list[0]: operation = outport.operation if operation.type_name() == Delay.type_name(): non_schedulable_ops.append(operation.graph_id) - else: + elif operation.graph_id not in self._input_times: schedule.start_times[operation.graph_id] = 0 - # handle second set in precedence graph (first operations) - for outport in prec_list[1]: - operation = outport.operation - schedule.start_times[operation.graph_id] = 0 - # handle the remaining sets - for outports in prec_list[2:]: + for outports in prec_list[1:]: for outport in outports: operation = outport.operation if operation.graph_id not in schedule.start_times: @@ -117,7 +234,9 @@ class ASAPScheduler(Scheduler): schedule.start_times[operation.graph_id] = op_start_time - self._handle_outputs(schedule, non_schedulable_ops) + self._place_outputs_asap(schedule, non_schedulable_ops) + if self._input_times: + self._place_outputs_on_given_times() schedule.remove_delays() max_end_time = schedule.get_max_end_time() @@ -141,26 +260,41 @@ class ALAPScheduler(Scheduler): schedule : Schedule Schedule to apply the scheduling algorithm on. 
""" - ASAPScheduler().apply_scheduling(schedule) - self.op_laps = {} + self._schedule = schedule + self._sfg = schedule._sfg + ASAPScheduler( + self._input_times, + self._output_delta_times, + ).apply_scheduling(schedule) + self._op_laps = {} + + if self._output_delta_times: + self._place_outputs_on_given_times() + + for output in schedule.sfg.find_by_type(Input): + output = cast(Output, output) + self._op_laps[output.graph_id] = 0 # move all outputs ALAP before operations for output in schedule.sfg.find_by_type(Output): output = cast(Output, output) - self.op_laps[output.graph_id] = 0 + self._op_laps[output.graph_id] = 0 + if output.graph_id in self._output_delta_times: + continue schedule.move_operation_alap(output.graph_id) # move all operations ALAP for step in reversed(schedule.sfg.get_precedence_list()): for outport in step: - if not isinstance(outport.operation, Delay): + op = outport.operation + if not isinstance(op, Delay) and op.graph_id not in self._input_times: new_unwrapped_start_time = schedule.start_times[ - outport.operation.graph_id - ] + schedule.forward_slack(outport.operation.graph_id) - self.op_laps[outport.operation.graph_id] = ( + op.graph_id + ] + schedule.forward_slack(op.graph_id) + self._op_laps[op.graph_id] = ( new_unwrapped_start_time // schedule._schedule_time ) - schedule.move_operation_alap(outport.operation.graph_id) + schedule.move_operation_alap(op.graph_id) # adjust the scheduling time if empty time slots have appeared in the start slack = min(schedule.start_times.values()) @@ -202,8 +336,8 @@ class ListScheduler(Scheduler): input_times: dict["GraphID", int] | None = None, output_delta_times: dict["GraphID", int] | None = None, ) -> None: - super() - self._logger = logger.getLogger("list_scheduler") + super().__init__(input_times, output_delta_times) + self._sort_order = sort_order if max_resources is not None: if not isinstance(max_resources, dict): @@ -233,42 +367,6 @@ class ListScheduler(Scheduler): ) 
self._max_concurrent_writes = max_concurrent_writes or 0 - if input_times is not None: - if not isinstance(input_times, dict): - raise ValueError("Provided input_times must be a dictionary.") - for key, value in input_times.items(): - if not isinstance(key, str): - raise ValueError("Provided input_times keys must be strings.") - if not isinstance(value, int): - raise ValueError("Provided input_times values must be integers.") - if any(time < 0 for time in input_times.values()): - raise ValueError("Provided input_times values must be non-negative.") - self._input_times = input_times - else: - self._input_times = {} - - if output_delta_times is not None: - if not isinstance(output_delta_times, dict): - raise ValueError("Provided output_delta_times must be a dictionary.") - for key, value in output_delta_times.items(): - if not isinstance(key, str): - raise ValueError( - "Provided output_delta_times keys must be strings." - ) - if not isinstance(value, int): - raise ValueError( - "Provided output_delta_times values must be integers." - ) - if any(time < 0 for time in output_delta_times.values()): - raise ValueError( - "Provided output_delta_times values must be non-negative." - ) - self._output_delta_times = output_delta_times - else: - self._output_delta_times = {} - - self._sort_order = sort_order - def apply_scheduling(self, schedule: "Schedule") -> None: """Applies the scheduling algorithm on the given Schedule. 
@@ -288,11 +386,14 @@ class ListScheduler(Scheduler): if self._input_times: self._place_inputs_on_given_times() + self._remaining_ops = [ + op_id for op_id in self._remaining_ops if op_id not in self._input_times + ] self._schedule_nonrecursive_ops() if self._output_delta_times: - self._handle_outputs() + self._place_outputs_on_given_times() if self._schedule._schedule_time is None: self._schedule.set_schedule_time(self._schedule.get_max_end_time()) @@ -574,10 +675,10 @@ class ListScheduler(Scheduler): alap_schedule = copy.copy(self._schedule) alap_schedule._schedule_time = None - alap_scheduler = ALAPScheduler() + alap_scheduler = ALAPScheduler(self._input_times, self._output_delta_times) alap_scheduler.apply_scheduling(alap_schedule) self._alap_start_times = alap_schedule.start_times - self._alap_op_laps = alap_scheduler.op_laps + self._alap_op_laps = alap_scheduler._op_laps self._alap_schedule_time = alap_schedule._schedule_time self._schedule.start_times = {} for key in self._schedule._laps: @@ -639,7 +740,6 @@ class ListScheduler(Scheduler): self._used_reads = {0: 0} self._current_time = 0 - self._op_laps = {} def _schedule_nonrecursive_ops(self) -> None: self._logger.debug("--- Non-Recursive Operation scheduling starting ---") @@ -697,83 +797,6 @@ class ListScheduler(Scheduler): else: self._used_reads[time] = 1 - def _place_inputs_on_given_times(self) -> None: - self._logger.debug("--- Input placement starting ---") - for input_id in self._input_times: - self._schedule.start_times[input_id] = self._input_times[input_id] - self._op_laps[input_id] = 0 - self._logger.debug( - f" {input_id} time: {self._schedule.start_times[input_id]}" - ) - self._remaining_ops = [ - op_id - for op_id in self._remaining_ops - if not isinstance(self._sfg.find_by_id(op_id), Input) - ] - self._logger.debug("--- Input placement completed ---") - - def _handle_outputs(self) -> None: - self._logger.debug("--- Output placement starting ---") - if self._schedule._cyclic: - end = 
self._schedule._schedule_time - else: - end = self._schedule.get_max_end_time() - for output in self._sfg.find_by_type(Output): - output = cast(Output, output) - if output.graph_id in self._output_delta_times: - delta_time = self._output_delta_times[output.graph_id] - - new_time = end + delta_time - - if self._schedule._cyclic and self._schedule._schedule_time is not None: - self._schedule.place_operation(output, new_time, self._op_laps) - else: - self._schedule.start_times[output.graph_id] = new_time - - count = -1 - for op_id, time in self._schedule.start_times.items(): - if time == new_time and isinstance( - self._sfg.find_by_id(op_id), Output - ): - count += 1 - - modulo_time = ( - new_time % self._schedule._schedule_time - if self._schedule._schedule_time - else new_time - ) - self._logger.debug(f" {output.graph_id} time: {modulo_time}") - self._logger.debug("--- Output placement completed ---") - - self._logger.debug("--- Output placement optimization starting ---") - min_slack = min( - self._schedule.backward_slack(op.graph_id) - for op in self._sfg.find_by_type(Output) - ) - if min_slack != 0: - for output in self._sfg.find_by_type(Output): - if self._schedule._cyclic and self._schedule._schedule_time is not None: - self._schedule.move_operation(output.graph_id, -min_slack) - else: - self._schedule.start_times[output.graph_id] = ( - self._schedule.start_times[output.graph_id] - min_slack - ) - new_time = self._schedule.start_times[output.graph_id] - if ( - not self._schedule._cyclic - and self._schedule._schedule_time is not None - and new_time > self._schedule._schedule_time - ): - raise ValueError( - f"Cannot place output {output.graph_id} at time {new_time} " - f"for scheduling time {self._schedule._schedule_time}. " - "Try to relax the scheduling time, change the output delta times or enable cyclic." 
- ) - self._logger.debug( - f" {output.graph_id} moved {min_slack} time steps backwards to new time {new_time}" - ) - self._logger.debug("--- Output placement optimization completed ---") - def _handle_dont_cares(self) -> None: # schedule all dont cares ALAP for dc_op in self._sfg.find_by_type(DontCare): @@ -808,6 +831,9 @@ class RecursiveListScheduler(ListScheduler): if self._input_times: self._place_inputs_on_given_times() + self._remaining_ops = [ + op_id for op_id in self._remaining_ops if op_id not in self._input_times + ] loops = self._sfg.loops if loops: @@ -816,7 +842,7 @@ class RecursiveListScheduler(ListScheduler): self._schedule_nonrecursive_ops() if self._output_delta_times: - self._handle_outputs() + self._place_outputs_on_given_times() if self._schedule._schedule_time is None: self._schedule.set_schedule_time(self._schedule.get_max_end_time()) diff --git a/examples/auto_scheduling_with_custom_io_times.py b/examples/auto_scheduling_with_custom_io_times.py index 3e70b0524677156fce66106430c1d2a0d470bfeb..381d982d1ed7c19e6fb3510d21a888a856f77f67 100644 --- a/examples/auto_scheduling_with_custom_io_times.py +++ b/examples/auto_scheduling_with_custom_io_times.py @@ -8,10 +8,13 @@ It is possible to specify the IO times and provide those to the scheduling. from b_asic.core_operations import Butterfly, ConstantMultiplication from b_asic.list_schedulers import HybridScheduler +from b_asic.logger import getLogger from b_asic.schedule import Schedule -from b_asic.scheduler import ASAPScheduler +from b_asic.scheduler import ALAPScheduler, ASAPScheduler from b_asic.sfg_generators import radix_2_dif_fft +getLogger("list_scheduler", console_log_level="debug") + points = 8 sfg = radix_2_dif_fft(points=points) @@ -27,16 +30,21 @@ sfg.set_execution_time_of_type_name(Butterfly.type_name(), 1) sfg.set_execution_time_of_type_name(ConstantMultiplication.type_name(), 1) # %% -# Generate an ASAP schedule for reference. 
-schedule1 = Schedule(sfg, scheduler=ASAPScheduler()) +# Generate an ASAP schedule for reference with custom IO times. +input_times = {f"in{i}": i for i in range(points)} +output_delta_times = {f"out{i}": i for i in range(points)} +schedule1 = Schedule(sfg, scheduler=ASAPScheduler(input_times, output_delta_times)) schedule1.show() +# %% +# Generate an ALAP schedule for reference with custom IO times.. +schedule_t = Schedule(sfg, scheduler=ALAPScheduler(input_times, output_delta_times)) +schedule_t.show() + # %% # Generate a non-cyclic Schedule from HybridScheduler with custom IO times, -# one input and output per time unit +# one input and output per time unit and one butterfly/multiplication per time unit. resources = {Butterfly.type_name(): 1, ConstantMultiplication.type_name(): 1} -input_times = {f"in{i}": i for i in range(points)} -output_delta_times = {f"out{i}": i for i in range(points)} schedule2 = Schedule( sfg, scheduler=HybridScheduler( diff --git a/test/unit/test_list_schedulers.py b/test/unit/test_list_schedulers.py index fbba4a3378c41bd1d8219238aad196042a803734..edff8a5da4d874f608872084130cb0ceccfb553c 100644 --- a/test/unit/test_list_schedulers.py +++ b/test/unit/test_list_schedulers.py @@ -1202,7 +1202,7 @@ class TestHybridScheduler: for i in range(POINTS): assert schedule.start_times[f"in{i}"] == i - assert schedule.start_times[f"out{i}"] == 95 + i + assert schedule.start_times[f"out{i}"] == 81 + i # too slow for pipeline timeout # def test_64_point_fft_custom_io_times(self): @@ -1258,12 +1258,7 @@ class TestHybridScheduler: for i in range(POINTS): assert schedule.start_times[f"in{i}"] == i - if i == 0: - expected_value = 95 - elif i == 1: - expected_value = 96 - else: - expected_value = i - 1 + expected_value = ((81 + i - 1) % 96) + 1 assert schedule.start_times[f"out{i}"] == expected_value def test_cyclic_scheduling(self): diff --git a/test/unit/test_schedule.py b/test/unit/test_schedule.py index 
b23c76a0865deb3e802698d77963d0792ce1ac29..5795dc64939a2291fe07df2fa927572c83d865ba 100644 --- a/test/unit/test_schedule.py +++ b/test/unit/test_schedule.py @@ -788,7 +788,7 @@ class TestErrors: def test_no_latency(self, sfg_simple_filter): with pytest.raises( ValueError, - match="Input port 0 of operation add0 has no latency-offset.", + match="Input port 0 of operation cmul0 has no latency-offset.", ): Schedule(sfg_simple_filter, scheduler=ASAPScheduler())