diff --git a/b_asic/resources.py b/b_asic/resources.py
index 97dffb6802daef8d1306f646add9dd090a43a7cc..bc850f1f29654948b013a6468eca197de66d659a 100644
--- a/b_asic/resources.py
+++ b/b_asic/resources.py
@@ -907,7 +907,7 @@ class ProcessCollection:
 
     def split_on_ports(
         self,
-        heuristic: str = "left_edge",
+        heuristic: str = "graph_color",
         read_ports: Optional[int] = None,
         write_ports: Optional[int] = None,
         total_ports: Optional[int] = None,
diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py
index 1ecfbb2d0f53dbeeb63cf754ad8fae4eaf2f5731..c4d5de352e7d2eafa936060ffc6aae9f9886c9f3 100644
--- a/b_asic/scheduler.py
+++ b/b_asic/scheduler.py
@@ -1,3 +1,4 @@
+import sys
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Optional, cast
 
@@ -9,6 +10,7 @@ from b_asic.types import TypeName
 if TYPE_CHECKING:
     from b_asic.operation import Operation
     from b_asic.schedule import Schedule
+    from b_asic.signal_flow_graph import SFG
     from b_asic.types import GraphID
 
 
@@ -44,10 +46,15 @@ class Scheduler(ABC):
             ] + cast(int, source_port.latency_offset)
 
 
+# TODO: Rename max_concurrent_reads/writes to max_concurrent_read_ports or something to signify difference
+
+
 class ListScheduler(Scheduler, ABC):
     def __init__(
         self,
         max_resources: Optional[dict[TypeName, int]] = None,
+        max_concurrent_reads: Optional[int] = None,
+        max_concurrent_writes: Optional[int] = None,
         input_times: Optional[dict["GraphID", int]] = None,
         output_delta_times: Optional[dict["GraphID", int]] = None,
         cyclic: Optional[bool] = False,
@@ -65,6 +72,13 @@ class ListScheduler(Scheduler, ABC):
         else:
             self._max_resources = {}
 
+        self._max_concurrent_reads = (
+            max_concurrent_reads if max_concurrent_reads else sys.maxsize
+        )
+        self._max_concurrent_writes = (
+            max_concurrent_writes if max_concurrent_writes else sys.maxsize
+        )
+
         self._input_times = input_times if input_times else {}
         self._output_delta_times = output_delta_times if output_delta_times else {}
 
@@ -77,45 +91,63 @@ class ListScheduler(Scheduler, ABC):
             Schedule to apply the scheduling algorithm on.
         """
 
         sfg = schedule.sfg
-        start_times = schedule.start_times
 
         used_resources_ready_times = {}
         remaining_resources = self._max_resources.copy()
         sorted_operations = self._get_sorted_operations(schedule)
 
+        schedule.start_times = {}
+
+        remaining_reads = self._max_concurrent_reads
+
         # initial input placement
         if self._input_times:
             for input_id in self._input_times:
-                start_times[input_id] = self._input_times[input_id]
+                schedule.start_times[input_id] = self._input_times[input_id]
 
         for input_op in sfg.find_by_type_name(Input.type_name()):
             if input_op.graph_id not in self._input_times:
-                start_times[input_op.graph_id] = 0
+                schedule.start_times[input_op.graph_id] = 0
 
         current_time = 0
+        timeout_counter = 0
         while sorted_operations:
 
             # generate the best schedulable candidate
             candidate = sfg.find_by_id(sorted_operations[0])
             counter = 0
             while not self._candidate_is_schedulable(
-                start_times,
+                schedule.start_times,
+                sfg,
                 candidate,
                 current_time,
                 remaining_resources,
+                remaining_reads,
+                self._max_concurrent_writes,
                 sorted_operations,
             ):
                 if counter == len(sorted_operations):
                     counter = 0
                     current_time += 1
+                    timeout_counter += 1
+
+                    if timeout_counter > 10:
+                        msg = "Algorithm did not schedule any operation for 10 time steps, try relaxing constraints."
+                        raise TimeoutError(msg)
+
+                    remaining_reads = self._max_concurrent_reads
+
                     # update available operators
                     for operation, ready_time in used_resources_ready_times.items():
                         if ready_time == current_time:
                             remaining_resources[operation.type_name()] += 1
+
                 else:
                     candidate = sfg.find_by_id(sorted_operations[counter])
                     counter += 1
+                    timeout_counter = 0
+
             # if the resource is constrained, update remaining resources
             if candidate.type_name() in remaining_resources:
                 remaining_resources[candidate.type_name()] -= 1
@@ -128,9 +160,11 @@ class ListScheduler(Scheduler, ABC):
                 current_time + candidate.latency
             )
 
+            remaining_reads -= candidate.input_count
+
             # schedule the best candidate to the current time
             sorted_operations.remove(candidate.graph_id)
-            start_times[candidate.graph_id] = current_time
+            schedule.start_times[candidate.graph_id] = current_time
 
         self._handle_outputs(schedule)
 
@@ -138,6 +172,7 @@ class ListScheduler(Scheduler, ABC):
         max_start_time = max(schedule.start_times.values())
         if current_time < max_start_time:
             current_time = max_start_time
+        current_time = max(current_time, schedule.get_max_end_time())
         schedule.set_schedule_time(current_time)
         schedule.remove_delays()
 
@@ -152,9 +187,12 @@ class ListScheduler(Scheduler, ABC):
     @staticmethod
     def _candidate_is_schedulable(
         start_times: dict["GraphID"],
+        sfg: "SFG",
         operation: "Operation",
         current_time: int,
         remaining_resources: dict["GraphID", int],
+        remaining_reads: int,
+        max_concurrent_writes: int,
         remaining_ops: list["GraphID"],
     ) -> bool:
         if (
@@ -163,20 +201,51 @@ class ListScheduler(Scheduler, ABC):
         ):
             return False
 
+        op_finish_time = current_time + operation.latency
+        future_ops = [
+            sfg.find_by_id(item[0])
+            for item in start_times.items()
+            if item[1] + sfg.find_by_id(item[0]).latency == op_finish_time
+        ]
+
+        future_ops_writes = sum([op.input_count for op in future_ops])
+
+        if (
+            not operation.graph_id.startswith("out")
+            and future_ops_writes >= max_concurrent_writes
+        ):
+            return False
+
+        read_counter = 0
         earliest_start_time = 0
         for op_input in operation.inputs:
             source_op = op_input.signals[0].source.operation
+            if isinstance(source_op, Delay):
+                continue
+
             source_op_graph_id = source_op.graph_id
 
             if source_op_graph_id in remaining_ops:
                 return False
 
+            if start_times[source_op_graph_id] != current_time - 1:
+                # not a direct connection -> memory read required
+                read_counter += 1
+
+                if read_counter > remaining_reads:
+                    return False
+
             proceeding_op_start_time = start_times.get(source_op_graph_id)
+            proceeding_op_finish_time = proceeding_op_start_time + source_op.latency
+
+            # if not proceeding_op_finish_time == current_time:
+            #     # not direct connection -> memory required, check if okay
+            #     satisfying_remaining_reads = remaining_reads >= operation.input_count
+            #     satisfying_remaining_writes = remaining_writes >= operation.output_count
+            #     if not (satisfying_remaining_reads and satisfying_remaining_writes):
+            #         return False
 
-            if not isinstance(source_op, Delay):
-                earliest_start_time = max(
-                    earliest_start_time, proceeding_op_start_time + source_op.latency
-                )
+            earliest_start_time = max(earliest_start_time, proceeding_op_finish_time)
 
         return earliest_start_time <= current_time
 
diff --git a/examples/auto_scheduling_with_custom_io_times.py b/examples/auto_scheduling_with_custom_io_times.py
index 6b6a90b614ac5adb8fe1ff00490b0f7b3f61f270..8913bfd853bfe650ba1546aa83488097618a435b 100644
--- a/examples/auto_scheduling_with_custom_io_times.py
+++ b/examples/auto_scheduling_with_custom_io_times.py
@@ -52,7 +52,12 @@ output_delta_times = {
"out7": 5, } schedule = Schedule( - sfg, scheduler=HybridScheduler(resources, input_times, output_delta_times) + sfg, + scheduler=HybridScheduler( + resources, + input_times=input_times, + output_delta_times=output_delta_times, + ), ) schedule.show() @@ -70,7 +75,11 @@ output_delta_times = { } schedule = Schedule( sfg, - scheduler=HybridScheduler(resources, input_times, output_delta_times), + scheduler=HybridScheduler( + resources, + input_times=input_times, + output_delta_times=output_delta_times, + ), cyclic=True, ) schedule.show() diff --git a/examples/ldlt_matrix_inverse.py b/examples/ldlt_matrix_inverse.py index 4432e41764f96508d1e5c2638849bef67d2c2e4a..cf5961aaa3b8f1641e44233df0bfe86eeb96fa7d 100644 --- a/examples/ldlt_matrix_inverse.py +++ b/examples/ldlt_matrix_inverse.py @@ -84,7 +84,9 @@ output_delta_times = { } schedule = Schedule( sfg, - scheduler=HybridScheduler(resources, input_times, output_delta_times), + scheduler=HybridScheduler( + resources, input_times=input_times, output_delta_times=output_delta_times + ), cyclic=True, ) print("Scheduling time:", schedule.schedule_time) @@ -137,4 +139,3 @@ arch = Architecture( # %% arch -# schedule.edit() diff --git a/examples/memory_constrained_scheduling.py b/examples/memory_constrained_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..6c0ea9a8e276167cca08bc3ad8d2942262f3151a --- /dev/null +++ b/examples/memory_constrained_scheduling.py @@ -0,0 +1,134 @@ +""" +========================================= +Memory Constrained Scheduling +========================================= + +""" + +from b_asic.architecture import Architecture, Memory, ProcessingElement +from b_asic.core_operations import Butterfly, ConstantMultiplication +from b_asic.core_schedulers import ASAPScheduler, HybridScheduler +from b_asic.schedule import Schedule +from b_asic.sfg_generators import radix_2_dif_fft +from b_asic.special_operations import Input, Output + +sfg = radix_2_dif_fft(points=16) + +# %% +# The SFG is +sfg + +# %% +# Set latencies and execution times. 
+sfg.set_latency_of_type(Butterfly.type_name(), 3)
+sfg.set_latency_of_type(ConstantMultiplication.type_name(), 2)
+sfg.set_execution_time_of_type(Butterfly.type_name(), 1)
+sfg.set_execution_time_of_type(ConstantMultiplication.type_name(), 1)
+
+# %%
+# Generate an ASAP schedule for reference
+schedule = Schedule(sfg, scheduler=ASAPScheduler())
+schedule.show()
+
+# %%
+# Generate a PE-constrained schedule with the HybridScheduler
+resources = {Butterfly.type_name(): 1, ConstantMultiplication.type_name(): 1}
+schedule = Schedule(sfg, scheduler=HybridScheduler(resources))
+schedule.show()
+
+# %%
+operations = schedule.get_operations()
+bfs = operations.get_by_type_name(Butterfly.type_name())
+bfs.show(title="Butterfly executions")
+const_muls = operations.get_by_type_name(ConstantMultiplication.type_name())
+const_muls.show(title="ConstMul executions")
+inputs = operations.get_by_type_name(Input.type_name())
+inputs.show(title="Input executions")
+outputs = operations.get_by_type_name(Output.type_name())
+outputs.show(title="Output executions")
+
+bf_pe = ProcessingElement(bfs, entity_name="bf")
+mul_pe = ProcessingElement(const_muls, entity_name="mul")
+
+pe_in = ProcessingElement(inputs, entity_name='input')
+pe_out = ProcessingElement(outputs, entity_name='output')
+
+mem_vars = schedule.get_memory_variables()
+mem_vars.show(title="All memory variables")
+direct, mem_vars = mem_vars.split_on_length()
+mem_vars.show(title="Non-zero time memory variables")
+mem_vars_set = mem_vars.split_on_ports(read_ports=1, write_ports=1, total_ports=2)
+
+# %%
+memories = []
+for i, mem in enumerate(mem_vars_set):
+    memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}")
+    memories.append(memory)
+    mem.show(title=f"{memory.entity_name}")
+    memory.assign("left_edge")
+    memory.show_content(title=f"Assigned {memory.entity_name}")
+
+direct.show(title="Direct interconnects")
+
+# %%
+arch = Architecture(
+    {bf_pe, mul_pe, pe_in, pe_out},
+    memories,
+    direct_interconnects=direct,
+)
+arch
+
+# %%
+# Generate another schedule with the HybridScheduler, but this time constrain the number of concurrent reads and writes to reduce the number of memories
+resources = {Butterfly.type_name(): 1, ConstantMultiplication.type_name(): 1}
+schedule = Schedule(
+    sfg,
+    scheduler=HybridScheduler(
+        resources, max_concurrent_reads=2, max_concurrent_writes=2
+    ),
+)
+schedule.show()
+
+# %% Print the max number of read and write port accesses to non-direct memories
+direct, mem_vars = schedule.get_memory_variables().split_on_length()
+print("Max read ports:", mem_vars.read_ports_bound())
+print("Max write ports:", mem_vars.write_ports_bound())
+
+# %% Proceed to construct PEs and plot executions and non-direct memory variables
+operations = schedule.get_operations()
+bfs = operations.get_by_type_name(Butterfly.type_name())
+bfs.show(title="Butterfly executions")
+const_muls = operations.get_by_type_name(ConstantMultiplication.type_name())
+const_muls.show(title="ConstMul executions")
+inputs = operations.get_by_type_name(Input.type_name())
+inputs.show(title="Input executions")
+outputs = operations.get_by_type_name(Output.type_name())
+outputs.show(title="Output executions")
+
+bf_pe = ProcessingElement(bfs, entity_name="bf")
+mul_pe = ProcessingElement(const_muls, entity_name="mul")
+
+pe_in = ProcessingElement(inputs, entity_name='input')
+pe_out = ProcessingElement(outputs, entity_name='output')
+
+mem_vars.show(title="Non-zero time memory variables")
+mem_vars_set = mem_vars.split_on_ports(read_ports=1, write_ports=1, total_ports=2)
+
+# %% Allocate memories by graph-coloring
+memories = []
+for i, mem in enumerate(mem_vars_set):
+    memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}")
+    memories.append(memory)
+    mem.show(title=f"{memory.entity_name}")
+    memory.assign("left_edge")
+    memory.show_content(title=f"Assigned {memory.entity_name}")
+
+direct.show(title="Direct interconnects")
+
+# %% Synthesize the new architecture, now using only two memories but with a higher data rate per memory
+arch = Architecture(
+    {bf_pe, mul_pe, pe_in, pe_out},
+    memories,
+    direct_interconnects=direct,
+)
+arch
diff --git a/test/test_core_schedulers.py b/test/test_core_schedulers.py
index 296974866613df2e207446e4feb9db038c203a2d..8404db9ce0a060b1ad5813757143386f3f45af21 100644
--- a/test/test_core_schedulers.py
+++ b/test/test_core_schedulers.py
@@ -861,7 +861,9 @@ class TestHybridScheduler:
         }
         schedule = Schedule(
             sfg,
-            scheduler=HybridScheduler(resources, input_times, output_times),
+            scheduler=HybridScheduler(
+                resources, input_times=input_times, output_delta_times=output_times
+            ),
             cyclic=True,
         )
 
@@ -933,7 +935,9 @@ class TestHybridScheduler:
         }
         schedule = Schedule(
             sfg,
-            scheduler=HybridScheduler(resources, input_times, output_times),
+            scheduler=HybridScheduler(
+                resources, input_times=input_times, output_delta_times=output_times
+            ),
             cyclic=False,
         )
 
@@ -1027,7 +1031,9 @@ class TestHybridScheduler:
         }
         schedule = Schedule(
             sfg,
-            scheduler=HybridScheduler(resources, input_times, output_times),
+            scheduler=HybridScheduler(
+                resources, input_times=input_times, output_delta_times=output_times
+            ),
             cyclic=True,
         )
 
@@ -1078,3 +1084,41 @@ class TestHybridScheduler:
         resources = {MADS.type_name(): "test"}
         with pytest.raises(ValueError, match="max_resources value must be an integer."):
             Schedule(sfg, scheduler=HybridScheduler(resources))
+
+    # def test_ldlt_inverse_2x2_read_constrained(self):
+    #     sfg = ldlt_matrix_inverse(N=2)
+
+    #     sfg.set_latency_of_type(MADS.type_name(), 3)
+    #     sfg.set_latency_of_type(Reciprocal.type_name(), 2)
+    #     sfg.set_execution_time_of_type(MADS.type_name(), 1)
+    #     sfg.set_execution_time_of_type(Reciprocal.type_name(), 1)
+
+    #     resources = {MADS.type_name(): 1, Reciprocal.type_name(): 1}
+    #     schedule = Schedule(
+    #         sfg,
+    #         scheduler=HybridScheduler(
+    #             max_resources = resources,
+    #             max_concurrent_reads = 3,
+    #         ),
+    #     )
+
+    def test_ldlt_inverse_2x2_read_constrained_too_low(self):
+        sfg = ldlt_matrix_inverse(N=2)
+
+        sfg.set_latency_of_type(MADS.type_name(), 3)
+        sfg.set_latency_of_type(Reciprocal.type_name(), 2)
+        sfg.set_execution_time_of_type(MADS.type_name(), 1)
+        sfg.set_execution_time_of_type(Reciprocal.type_name(), 1)
+
+        resources = {MADS.type_name(): 1, Reciprocal.type_name(): 1}
+        with pytest.raises(
+            TimeoutError,
+            match="Algorithm did not schedule any operation for 10 time steps, try relaxing constraints.",
+        ):
+            Schedule(
+                sfg,
+                scheduler=HybridScheduler(
+                    max_resources=resources,
+                    max_concurrent_reads=2,
+                ),
+            )