Skip to content
Snippets Groups Projects
simulation.py 4.91 KiB
Newer Older
  • Learn to ignore specific revisions
    """@package docstring

    B-ASIC Simulation Module.
    Contains the Simulation class, which runs a signal flow graph (SFG) one iteration at a time.
    """
    
    
    from collections import defaultdict
    from numbers import Number
    from typing import List, Dict, DefaultDict, Callable, Sequence, Mapping, Union, Optional

    from b_asic.operation import ResultKey, ResultMap
    from b_asic.signal_flow_graph import SFG
    
    
    # A source of values for a single SFG input, as accepted by Simulation.set_input():
    # a constant Number (used for every iteration), a Sequence indexed by the
    # iteration number, or a callable mapping the iteration number to a value.
    InputProvider = Union[Number, Sequence[Number], Callable[[int], Number]]
    
    
    class Simulation:
        """Simulation of a signal flow graph (SFG), advanced one iteration at a time.

        For every iteration, one value per SFG input is fetched from the configured
        input providers and handed to the SFG, which computes the output values.
        The same register map is passed to the SFG on every iteration, and the
        per-iteration result maps are kept only while save_results is enabled.
        """

        # The signal flow graph being simulated.
        _sfg: SFG
        # Saved results: iteration number -> result map filled in by the SFG.
        _results: DefaultDict[int, Dict[str, Number]]
        # Register map handed to the SFG on every evaluation.
        _registers: Dict[str, Number]
        # Number of iterations run so far (also the index of the next iteration).
        _iteration: int
        # One function per SFG input, mapping an iteration number to an input value.
        _input_functions: List[Callable[[int], Number]]
        # Input values used by the most recently run iteration.
        _current_input_values: Sequence[Number]
        # Output values produced by the most recently run iteration.
        _latest_output_values: Sequence[Number]
        # Whether the result map of each iteration is kept in _results.
        _save_results: bool

        def __init__(self, sfg: SFG, input_providers: Optional[Sequence[Optional[InputProvider]]] = None, save_results: bool = False):
            """Create a simulation of the given SFG.

            Every input starts out with a provider that yields 0 for all iterations.
            If input_providers is given it is applied through set_inputs(), so it
            must hold exactly one entry per SFG input; None entries keep the
            zero provider.

            Raises ValueError if input_providers has the wrong length.
            """
            self._sfg = sfg
            self._results = defaultdict(dict)
            self._registers = {}
            self._iteration = 0
            self._input_functions = [lambda _: 0 for _ in range(self._sfg.input_count)]
            self._current_input_values = [0 for _ in range(self._sfg.input_count)]
            self._latest_output_values = [0 for _ in range(self._sfg.output_count)]
            self._save_results = save_results
            if input_providers is not None:
                self.set_inputs(input_providers)

        def set_input(self, index: int, input_provider: InputProvider) -> None:
            """Set the input function used to get values for the specific input at the given index to the internal SFG.

            A callable is used as-is and receives the iteration number. A Number
            is returned unchanged for every iteration. Anything else is treated as
            a sequence and indexed with the iteration number, so it has to cover
            every iteration that will be run.

            Raises IndexError if index does not refer to an existing input.
            """
            if index < 0 or index >= len(self._input_functions):
                raise IndexError(f"Input index out of range (expected 0-{len(self._input_functions) - 1}, got {index})")
            if callable(input_provider):
                self._input_functions[index] = input_provider
            elif isinstance(input_provider, Number):
                self._input_functions[index] = lambda _: input_provider
            else:
                self._input_functions[index] = lambda n: input_provider[n]

        def set_inputs(self, input_providers: Sequence[Optional[InputProvider]]) -> None:
            """Set the input functions used to get values for the inputs to the internal SFG.

            input_providers must hold exactly one entry per SFG input. A None
            entry leaves the provider currently installed for that input untouched
            (initially the zero provider), so every input always stays callable.

            Raises ValueError if the number of entries does not match the SFG input count.
            """
            if len(input_providers) != self._sfg.input_count:
                raise ValueError(f"Wrong number of inputs supplied to simulation (expected {self._sfg.input_count}, got {len(input_providers)})")
            # The existing functions are deliberately not reset here: overwriting them
            # with None placeholders would make run_until() fail on skipped inputs.
            for index, input_provider in enumerate(input_providers):
                if input_provider is not None:
                    self.set_input(index, input_provider)

        @property
        def save_results(self) -> bool:
            """Get the flag that determines if the results of each iteration are saved."""
            return self._save_results

        @save_results.setter
        def save_results(self, save_results: bool) -> None:
            """Set the flag that determines if the results of each iteration are saved."""
            self._save_results = save_results

        def run(self) -> Sequence[Number]:
            """Run one iteration of the simulation and return the resulting output values."""
            return self.run_for(1)

        def run_until(self, iteration: int) -> Sequence[Number]:
            """Run the simulation until its iteration is greater than or equal to the given iteration
            and return the resulting output values.

            If the simulation is already at or past the given iteration nothing is
            run and the most recent output values are returned.
            """
            while self._iteration < iteration:
                self._current_input_values = [provider(self._iteration) for provider in self._input_functions]
                # Only touch the saved results when saving is enabled; otherwise the
                # SFG writes into a throwaway dict, so a failing evaluation cannot
                # leave a partial entry behind.
                iteration_results = self._results[self._iteration] if self._save_results else {}
                self._latest_output_values = self._sfg.evaluate_outputs(self._current_input_values, iteration_results, self._registers)
                self._iteration += 1
            return self._latest_output_values

        def run_for(self, iterations: int) -> Sequence[Number]:
            """Run a given number of iterations of the simulation and return the resulting output values."""
            return self.run_until(self._iteration + iterations)

        @property
        def iteration(self) -> int:
            """Get the current iteration number of the simulation."""
            return self._iteration

        @property
        def results(self) -> Mapping[int, ResultMap]:
            """Get a mapping of all results, including intermediate values, calculated for each iteration up until now.
            The outer mapping maps from iteration number to value mapping. The value mapping maps output port identifiers to values.
            Example: {0: {"c1": 3, "c2": 4, "bfly1.0": 7, "bfly1.1": -1, "0": 7}}

            Only iterations run while save_results was enabled are present. The
            returned object is the live internal defaultdict, so indexing an
            iteration that is missing creates an empty entry; use .get() or "in"
            to probe it.
            """
            return self._results

        def clear_results(self) -> None:
            """Clear all results that were saved until now."""
            self._results.clear()

        def clear_state(self) -> None:
            """Clear all current state of the simulation, except for the results and iteration."""
            self._registers.clear()
            self._current_input_values = [0 for _ in range(self._sfg.input_count)]
            self._latest_output_values = [0 for _ in range(self._sfg.output_count)]