diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py index 0a1b170f636b6c645d4e29ac2eebc9a950594c61..74a892c042699f9bcb3beac5d4d6e506c9795d1c 100644 --- a/b_asic/scheduler.py +++ b/b_asic/scheduler.py @@ -143,12 +143,6 @@ class ALAPScheduler(Scheduler): Schedule to apply the scheduling algorithm on. """ ASAPScheduler().apply_scheduling(schedule) - # max_end_time = schedule.get_max_end_time() - - # if schedule.schedule_time is None: - # schedule.set_schedule_time(max_end_time) - # elif schedule.schedule_time < max_end_time: - # raise ValueError(f"Too short schedule time. Minimum is {max_end_time}.") # move all outputs ALAP before operations for output in schedule.sfg.find_by_type_name(Output.type_name()): @@ -235,7 +229,7 @@ class ListScheduler(Scheduler, ABC): ) alap_schedule = copy.copy(self._schedule) - alap_schedule.set_schedule_time(sys.maxsize) + alap_schedule._schedule_time = None ALAPScheduler().apply_scheduling(alap_schedule) alap_start_times = alap_schedule.start_times self._schedule.start_times = {} @@ -248,14 +242,9 @@ class ListScheduler(Scheduler, ABC): f"{alap_schedule.schedule_time}." ) - used_resources_ready_times = {} self._remaining_resources = self._max_resources.copy() - remaining_ops = ( - self._sfg.operations - # + self._sfg.find_by_type_name(Input.type_name()) - # + self._sfg.find_by_type_name(Output.type_name()) - ) + remaining_ops = self._sfg.operations remaining_ops = [op.graph_id for op in remaining_ops] self._schedule.start_times = {} @@ -292,24 +281,12 @@ class ListScheduler(Scheduler, ABC): self._get_next_op_id(ready_ops_priority_table) ) - if next_op.type_name() in self._remaining_resources: - self._remaining_resources[next_op.type_name()] -= 1 - if self._schedule.schedule_time is not None: - used_resources_ready_times[next_op] = ( - self._current_time + max(next_op.execution_time, 1) - ) % self._schedule.schedule_time - else: - used_resources_ready_times[next_op] = self._current_time + max( - next_op.execution_time, 1 - ) - self.remaining_reads -= next_op.input_count remaining_ops = [ op_id for op_id in remaining_ops if op_id != next_op.graph_id ] - print("Next:", next_op.graph_id, self._current_time) self._time_out_counter = 0 self._schedule.place_operation(next_op, self._current_time) self._op_laps[next_op.graph_id] = ( @@ -336,20 +313,7 @@ class ListScheduler(Scheduler, ABC): remaining_ops, ) - # update available reads and operators - if self._schedule.schedule_time is not None: - time = self._current_time % self._schedule.schedule_time - else: - time = self._current_time - self.remaining_reads = self._max_concurrent_reads - for operation, ready_time in used_resources_ready_times.items(): - if ready_time >= time: - self._remaining_resources[operation.type_name()] += 1 - - used_resources_ready_times = dict( - [pair for pair in used_resources_ready_times.items() if pair[1] > time] - ) self._current_time -= 1 @@ -435,13 +399,31 @@ class ListScheduler(Scheduler, ABC): for op_id, start_time in alap_start_times.items() } + def _op_satisfies_resource_constraints(self, op: "Operation") -> bool: + if self._schedule.schedule_time is not None: + time_slot = self._current_time % self._schedule.schedule_time + else: + time_slot = self._current_time + + count = 0 + for op_id, start_time in self._schedule.start_times.items(): + if self._schedule.schedule_time is not None: + start_time = start_time % self._schedule.schedule_time + + if time_slot >= start_time: + if time_slot < start_time + max( + self._sfg.find_by_id(op_id).execution_time, 1 + ): + if op_id.startswith(op.type_name()): + if op.graph_id != op_id: + count += 1 + + return count < self._remaining_resources[op.type_name()] + def _op_is_schedulable( self, op: "Operation", remaining_ops: list["GraphID"] ) -> bool: - if ( - op.type_name() in self._remaining_resources - and self._remaining_resources[op.type_name()] == 0 - ): + if not self._op_satisfies_resource_constraints(op): return False op_finish_time = self._current_time + op.latency diff --git a/examples/ldlt_matrix_inverse.py b/examples/ldlt_matrix_inverse.py index ffe74c83ff45830541ddda8d35e69b4b78d79314..a6525b2a114aa89e1f56fb3c5222acdc79cde0db 100644 --- a/examples/ldlt_matrix_inverse.py +++ b/examples/ldlt_matrix_inverse.py @@ -86,6 +86,7 @@ schedule = Schedule( scheduler=HybridScheduler( resources, input_times=input_times, output_delta_times=output_delta_times ), + schedule_time=32, cyclic=True, ) print("Scheduling time:", schedule.schedule_time)