diff --git a/b_asic/operation.py b/b_asic/operation.py index 29ea9fef949f1957d897388e3b2d8009b4da30e3..18f0b83f014592e72761786cabacc0a631551dba 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -231,10 +231,10 @@ class Operation(GraphComponent, SignalSourceProvider): which ignores the word length specified by the input signal. The *truncate* parameter specifies whether input truncation should be enabled in the first place. If set to False, input values will be used directly without any bit truncation. - + See also ======== - + evaluate_outputs, current_output, current_outputs """ raise NotImplementedError @@ -931,7 +931,13 @@ class AbstractOperation(Operation, AbstractGraphComponent): self._execution_time *= factor for port in [*self.inputs, *self.outputs]: port.latency_offset *= factor - + + def _decrease_time_resolution(self, factor: int): + if self._execution_time is not None: + self._execution_time = self._execution_time // factor + for port in [*self.inputs, *self.outputs]: + port.latency_offset = port.latency_offset // factor + def get_plot_coordinates( self, ) -> Tuple[List[List[Number]], List[List[Number]]]: diff --git a/b_asic/schedule.py b/b_asic/schedule.py index 2c5278835a184a01829ce617d961ef0b28ccd93b..df492ac973446797320bce42f6ff9d458070d5a8 100644 --- a/b_asic/schedule.py +++ b/b_asic/schedule.py @@ -5,6 +5,7 @@ Contains the schedule class for scheduling operations in an SFG. """ import io +import math import sys from collections import defaultdict from typing import Dict, List, Optional, Tuple @@ -191,12 +192,61 @@ class Schedule: self._start_times = { k: factor * v for k, v in self._start_times.items() } - for op_id, op_start_time in self._start_times.items(): + for op_id in self._start_times: self._sfg.find_by_id(op_id)._increase_time_resolution(factor) self._schedule_time *= factor + return self + + def _get_all_times(self) -> List[int]: + """ + Return a list of all times for the schedule. Used to check how the + resolution can be modified. + """ + # Local values + ret = [self._schedule_time, *self._start_times.values()] + # Loop over operations + for op_id in self._start_times: + op = self._sfg.find_by_id(op_id) + ret += [op.execution_time, *op.latency_offsets.values()] + # Remove not set values (None) + ret = [v for v in ret if v is not None] + return ret + + def get_possible_time_resolution_decrements(self) -> List[int]: + """Return a list with possible factors to reduce time resolution.""" + vals = self._get_all_times() + maxloop = min(val for val in vals if val) + if maxloop <= 1: + return [1] + ret = [1] + for candidate in range(2, maxloop + 1): + if not any(val % candidate for val in vals): + ret.append(candidate) + return ret def decrease_time_resolution(self, factor: int) -> "Schedule": - raise NotImplementedError + """ + Decrease time resolution for a schedule. + + Parameters + ========== + + factor : int + The time resolution decrement. + """ + possible_values = self.get_possible_time_resolution_decrements() + if factor not in possible_values: + raise ValueError( + f"Not possible to decrease resolution with {factor}. Possible" + f" values are {possible_values}" + ) + self._start_times = { + k: v // factor for k, v in self._start_times.items() + } + for op_id, _ in self._start_times.items(): + self._sfg.find_by_id(op_id)._decrease_time_resolution(factor) + self._schedule_time = self._schedule_time // factor + return self def move_operation(self, op_id: GraphID, time: int) -> "Schedule": assert ( diff --git a/test/test_schedule.py b/test/test_schedule.py index 7c47603eaa5010a69f7c847fd936c4db7bc95a22..dadef84c28fbd75fb662512ce2b47374d2434894 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -305,6 +305,7 @@ class TestTimeResolution: scheduling_alg="ASAP", ) old_schedule_time = schedule.schedule_time + assert schedule.get_possible_time_resolution_decrements() == [1] schedule.increase_time_resolution(2) @@ -330,6 +331,7 @@ class TestTimeResolution: } assert 2 * old_schedule_time == schedule.schedule_time + assert schedule.get_possible_time_resolution_decrements() == [1, 2] def test_increase_time_resolution_twice( self, sfg_two_inputs_two_outputs_independent_with_cmul @@ -365,3 +367,72 @@ class TestTimeResolution: } assert 6 * old_schedule_time == schedule.schedule_time + assert schedule.get_possible_time_resolution_decrements() == [ + 1, + 2, + 3, + 6, + ] + + def test_increase_decrease_time_resolution( + self, sfg_two_inputs_two_outputs_independent_with_cmul + ): + schedule = Schedule( + sfg_two_inputs_two_outputs_independent_with_cmul, + scheduling_alg="ASAP", + ) + old_schedule_time = schedule.schedule_time + assert schedule.get_possible_time_resolution_decrements() == [1] + + schedule.increase_time_resolution(6) + + start_times_names = {} + for op_id, start_time in schedule._start_times.items(): + op_name = ( + sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id( + op_id + ).name + ) + start_times_names[op_name] = start_time + + assert start_times_names == { + "C1": 0, + "IN1": 0, + "IN2": 0, + "CMUL1": 0, + "CMUL2": 30, + "ADD1": 0, + "CMUL3": 42, + "OUT1": 54, + "OUT2": 60, + } + + with pytest.raises( + ValueError, match="Not possible to decrease resolution" + ): + schedule.decrease_time_resolution(4) + + schedule.decrease_time_resolution(3) + start_times_names = {} + for op_id, start_time in schedule._start_times.items(): + op_name = ( + sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id( + op_id + ).name + ) + start_times_names[op_name] = start_time + + assert start_times_names == { + "C1": 0, + "IN1": 0, + "IN2": 0, + "CMUL1": 0, + "CMUL2": 10, + "ADD1": 0, + "CMUL3": 14, + "OUT1": 18, + "OUT2": 20, + } + + assert 2 * old_schedule_time == schedule.schedule_time + assert schedule.get_possible_time_resolution_decrements() == [1, 2] diff --git a/test/test_sfg.py b/test/test_sfg.py index 63287b1d7b82f896f02a11af4813b2ec786fa10b..4a2831603d37d3cc3b1a8da7868719a2e105e572 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1379,10 +1379,10 @@ class TestPrecedenceGraph: " shape=square]\n\tcmul1 -> \"cmul1.0\"\n\tcmul1 [label=cmul1" " shape=square]\n\t\"add1.0\" -> t1In\n\tt1In [label=t1" " shape=square]\n\tadd1 -> \"add1.0\"\n\tadd1 [label=add1" - " shape=square]\n}\n" + " shape=square]\n}" ) - assert sfg_simple_filter.precedence_graph().source == res + assert sfg_simple_filter.precedence_graph().source in (res, res + "\n") class TestSFGGraph: @@ -1391,20 +1391,20 @@ class TestSFGGraph: "digraph {\n\trankdir=LR\n\tin1\n\tin1 -> " "add1\n\tout1\n\tt1 -> out1\n\tadd1\n\tcmul1 -> " "add1\n\tcmul1\n\tadd1 -> t1\n\tt1 [shape=square]\n\tt1 " - "-> cmul1\n}\n" + "-> cmul1\n}" ) - assert sfg_simple_filter.sfg().source == res + assert sfg_simple_filter.sfg().source in (res, res + "\n") def test_sfg_show_id(self, sfg_simple_filter): res = ( "digraph {\n\trankdir=LR\n\tin1\n\tin1 -> add1 " "[label=s1]\n\tout1\n\tt1 -> out1 [label=s2]\n\tadd1" "\n\tcmul1 -> add1 [label=s3]\n\tcmul1\n\tadd1 -> t1 " - "[label=s4]\n\tt1 [shape=square]\n\tt1 -> cmul1 [label=s5]\n}\n" + "[label=s4]\n\tt1 [shape=square]\n\tt1 -> cmul1 [label=s5]\n}" ) - assert sfg_simple_filter.sfg(show_id=True).source == res + assert sfg_simple_filter.sfg(show_id=True).source in (res, res + "\n") def test_show_sfg_invalid_format(self, sfg_simple_filter): with pytest.raises(ValueError):