Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
simulation.py 4.99 KiB
"""@package docstring
B-ASIC Simulation Module.
TODO: More info.
"""

from collections import defaultdict
from numbers import Number
from typing import List, Dict, DefaultDict, Callable, Sequence, Mapping, Union, Optional

from b_asic.operation import OutputKey, OutputMap
from b_asic.signal_flow_graph import SFG


# A source of values for a single simulation input. Simulation.set_input
# accepts three forms: a plain number (used as a constant for every
# iteration), a sequence (indexed by the iteration number), or a callable
# (called with the iteration number).
InputProvider = Union[Number, Sequence[Number], Callable[[int], Number]]


class Simulation:
    """Step-by-step simulation of a signal flow graph (SFG).

    For every iteration, one input provider per SFG input is queried with the
    current iteration number, the collected values are passed to
    ``evaluate_outputs`` of the SFG, and the returned output values are kept
    as the latest outputs. Per-iteration results are only retained while
    ``save_results`` is enabled.

    Types that are declared outside this class (``SFG``, ``InputProvider``,
    ``OutputMap``) are annotated as forward-reference strings so that they are
    not evaluated when the class is defined.
    """

    _sfg: "SFG"
    _results: DefaultDict[int, Dict[str, Number]]
    _registers: Dict[str, Number]
    _iteration: int
    # A list (not just a Sequence): entries are replaced by index in set_input.
    _input_functions: List[Callable[[int], Number]]
    _current_input_values: Sequence[Number]
    _latest_output_values: Sequence[Number]
    _save_results: bool

    def __init__(self, sfg: "SFG", input_providers: "Optional[Sequence[Optional[InputProvider]]]" = None, save_results: bool = False):
        """Create a simulation of the given SFG.

        Args:
            sfg: The signal flow graph to simulate. Its ``input_count``,
                ``output_count`` and ``evaluate_outputs`` are used.
            input_providers: Optional providers, one per SFG input, forwarded
                to ``set_inputs``. Inputs without a provider yield 0.
            save_results: Whether the results of each iteration are kept.

        Raises:
            ValueError: If ``input_providers`` is given but its length does
                not match the number of inputs of the SFG.
        """
        self._sfg = sfg
        self._results = defaultdict(dict)
        self._registers = {}
        self._iteration = 0
        # Every input starts out as the constant 0 until a provider is set.
        self._input_functions = [
            lambda _: 0 for _ in range(self._sfg.input_count)]
        self._current_input_values = [0 for _ in range(self._sfg.input_count)]
        self._latest_output_values = [0 for _ in range(self._sfg.output_count)]
        self._save_results = save_results
        if input_providers is not None:
            self.set_inputs(input_providers)

    def set_input(self, index: int, input_provider: "InputProvider") -> None:
        """Set the input function used to get values for the specific input at the given index to the internal SFG.

        Args:
            index: Zero-based index of the input to configure.
            input_provider: A callable taking the iteration number, a number
                used as a constant for every iteration, or a sequence that is
                indexed by the iteration number.

        Raises:
            IndexError: If ``index`` is negative or not less than the number
                of inputs.
        """
        if index < 0 or index >= len(self._input_functions):
            raise IndexError(
                f"Input index out of range (expected 0-{len(self._input_functions) - 1}, got {index})")
        if callable(input_provider):
            self._input_functions[index] = input_provider
        elif isinstance(input_provider, Number):
            self._input_functions[index] = lambda _: input_provider
        else:
            # Anything else is treated as a sequence; indexing past its end
            # raises from the sequence itself when the simulation is run.
            self._input_functions[index] = lambda n: input_provider[n]

    def set_inputs(self, input_providers: "Sequence[Optional[InputProvider]]") -> None:
        """Set the input functions used to get values for the inputs to the internal SFG.

        An entry that is ``None`` leaves the provider currently installed for
        that input untouched (the constant 0 right after construction), so the
        simulation always has a callable for every input.

        Args:
            input_providers: One provider (or ``None``) per SFG input.

        Raises:
            ValueError: If the number of providers does not match the number
                of inputs of the SFG.
        """
        if len(input_providers) != self._sfg.input_count:
            raise ValueError(
                f"Wrong number of inputs supplied to simulation (expected {self._sfg.input_count}, got {len(input_providers)})")
        # Do not reset the list to None placeholders: a None provider would
        # otherwise be called (and fail) on the next run.
        for index, input_provider in enumerate(input_providers):
            if input_provider is not None:
                self.set_input(index, input_provider)

    @property
    def save_results(self) -> bool:
        """Get the flag that determines if the results of each iteration are saved."""
        return self._save_results

    @save_results.setter
    def save_results(self, save_results: bool) -> None:
        """Set the flag that determines if the results of each iteration are saved."""
        self._save_results = save_results

    def run(self) -> Sequence[Number]:
        """Run one iteration of the simulation and return the resulting output values."""
        return self.run_for(1)

    def run_until(self, iteration: int) -> Sequence[Number]:
        """Run the simulation until its iteration is greater than or equal to the given iteration
        and return the resulting output values.

        If the simulation has already reached ``iteration``, nothing is
        evaluated and the latest output values are returned as they are.
        """
        while self._iteration < iteration:
            current = self._iteration
            self._current_input_values = [
                provider(current) for provider in self._input_functions]
            # The SFG receives this iteration's result dict and the register
            # dict; presumably it fills/updates them in place (confirm
            # against SFG.evaluate_outputs).
            self._latest_output_values = self._sfg.evaluate_outputs(
                self._current_input_values, self._results[current], self._registers)
            if not self._save_results:
                # The entry was only created to be handed to the SFG.
                del self._results[current]
            self._iteration += 1
        return self._latest_output_values

    def run_for(self, iterations: int) -> Sequence[Number]:
        """Run a given number of iterations of the simulation and return the resulting output values."""
        return self.run_until(self._iteration + iterations)

    @property
    def iteration(self) -> int:
        """Get the current iteration number of the simulation."""
        return self._iteration

    @property
    def results(self) -> "Mapping[int, OutputMap]":
        """Get a mapping of all results, including intermediate values, calculated for each iteration up until now.
        The outer mapping maps from iteration number to value mapping. The value mapping maps output port identifiers to values.
        Example: {0: {"c1": 3, "c2": 4, "bfly1.0": 7, "bfly1.1": -1, "0": 7}}

        The returned object is the internal defaultdict itself, so indexing
        it with an iteration that has no entry inserts an empty entry; use
        ``in`` or ``get`` to probe without modifying it.
        """
        return self._results

    def clear_results(self) -> None:
        """Clear all results that were saved until now."""
        self._results.clear()

    def clear_state(self) -> None:
        """Clear all current state of the simulation, except for the results and iteration.

        Registers are emptied and the current input and latest output values
        are reset to zeros. Input providers are kept.
        """
        self._registers.clear()
        self._current_input_values = [0 for _ in range(self._sfg.input_count)]
        self._latest_output_values = [0 for _ in range(self._sfg.output_count)]