"""
B-ASIC Simulation Module.

Contains a class for simulating the result of an SFG given a set of input values.
"""

from collections import defaultdict
from numbers import Number
from typing import (
    Callable,
    List,
    Mapping,
    MutableMapping,
    MutableSequence,
    Optional,
    Sequence,
    Union,
)

import numpy as np

from b_asic.operation import MutableDelayMap, ResultKey
from b_asic.signal_flow_graph import SFG
from b_asic.types import Num

# Read-only view of saved results: each result key maps to a sequence of values.
ResultArrayMap = Mapping[ResultKey, Sequence[Num]]
# Mutable counterpart of ResultArrayMap, used for the storage that the
# simulation appends to while running.
MutableResultArrayMap = MutableMapping[ResultKey, MutableSequence[Num]]
# Callable invoked with an integer (the simulation passes its current iteration
# number) that returns the input value to use.
InputFunction = Callable[[int], Num]
# Anything accepted as the source of an input: a constant number, a sequence of
# values indexed by iteration, or an InputFunction.
InputProvider = Union[Num, Sequence[Num], InputFunction]


class Simulation:
    """
    Simulation of an SFG.

    Use FastSimulation (from the C++ extension module) for a more effective
    simulation when running many iterations.

    Parameters
    ----------
    sfg : SFG
        The signal flow graph to simulate.
    input_providers : list, optional
        Input values, one list item per input. Each list item can be an array of
        values, a callable taking a time index and returning the value, or a
        number (constant input). If a value is not provided for an input, it
        will be 0.
    """

    # Private copy of the simulated SFG.
    _sfg: SFG
    # Saved values per result key, one entry per iteration run with saving on.
    _results: MutableResultArrayMap
    # Delay state handed to the SFG on every evaluation.
    _delays: MutableDelayMap
    # Number of iterations that have been run so far.
    _iteration: int
    # One function per SFG input, called with the current iteration number.
    _input_functions: List[InputFunction]
    # Common length of the sequence inputs, or None if no sequence has been set.
    _input_length: Optional[int]

    def __init__(
        self,
        sfg: SFG,
        input_providers: Optional[Sequence[Optional[InputProvider]]] = None,
    ):
        """
        Construct a Simulation of an SFG.

        Raises
        ------
        TypeError
            If *sfg* is not an SFG instance.
        """
        if not isinstance(sfg, SFG):
            raise TypeError("An SFG must be provided")

        self._iteration = 0
        self._input_length = None
        self._delays = {}
        self._results = defaultdict(list)

        # Copy the SFG to make sure it's not modified from the outside.
        self._sfg = sfg()

        # Until told otherwise, every input is a constant zero.
        input_count = self._sfg.input_count
        self._input_functions = [lambda _: 0 for _ in range(input_count)]

        if input_providers is not None:
            self.set_inputs(input_providers)

    def set_input(self, index: int, input_provider: InputProvider) -> None:
        """
        Set the provider of values for the input at *index* of the internal SFG.

        Parameters
        ----------
        index : int
            The input index.
        input_provider : list, callable, or number
            Can be an array of values, a callable taking a time index and returning
            the value, or a number (constant input).

        Raises
        ------
        IndexError
            If *index* does not refer to an existing input.
        ValueError
            If *input_provider* is a sequence whose length differs from the
            length of a previously set sequence input.
        """
        input_total = len(self._input_functions)
        if index >= input_total or index < 0:
            last_index = input_total - 1
            raise IndexError(
                "Input index out of range (expected"
                f" 0-{last_index}, got {index})"
            )

        # Callables are used directly as the input function.
        if callable(input_provider):
            self._input_functions[index] = input_provider
            return

        # A plain number becomes a constant input.
        if isinstance(input_provider, Number):
            self._input_functions[index] = lambda _: input_provider
            return

        # Anything else is treated as a sequence indexed by iteration number.
        # All sequence inputs must agree on their length.
        provided_length = len(input_provider)
        if self._input_length is None:
            self._input_length = provided_length
        elif self._input_length != provided_length:
            raise ValueError(
                "Inconsistent input length for simulation (was"
                f" {self._input_length}, got {provided_length})"
            )
        self._input_functions[index] = lambda n: input_provider[n]

    def set_inputs(self, input_providers: Sequence[Optional[InputProvider]]) -> None:
        """
        Set the providers of values for all inputs of the internal SFG.

        Parameters
        ----------
        input_providers : sequence
            One item per input. Items that are None leave the corresponding
            input unchanged; the others are passed on to :meth:`set_input`.

        Raises
        ------
        ValueError
            If the number of items differs from the number of SFG inputs.
        """
        expected_count = self._sfg.input_count
        supplied_count = len(input_providers)
        if supplied_count != expected_count:
            raise ValueError(
                "Wrong number of inputs supplied to simulation (expected"
                f" {expected_count}, got {supplied_count})"
            )
        for index, input_provider in enumerate(input_providers):
            if input_provider is None:
                continue
            self.set_input(index, input_provider)

    def step(
        self,
        save_results: bool = True,
        bits_override: Optional[int] = None,
        truncate: bool = True,
    ) -> Sequence[Num]:
        """
        Run one iteration of the simulation and return the resulting output values.

        All arguments are passed on unchanged to :meth:`run_for`.
        """
        single_iteration = 1
        return self.run_for(single_iteration, save_results, bits_override, truncate)

    def run_until(
        self,
        iteration: int,
        save_results: bool = True,
        bits_override: Optional[int] = None,
        truncate: bool = True,
    ) -> Sequence[Num]:
        """
        Run the simulation until its iteration reaches the given iteration.

        Parameters
        ----------
        iteration : int
            The simulation runs while its iteration number is less than this.
        save_results : bool, default: True
            Whether the values produced by each iteration are appended to the
            saved results.
        bits_override : int, optional
            Passed on to ``evaluate_outputs`` of the SFG.
        truncate : bool, default: True
            Passed on to ``evaluate_outputs`` of the SFG.

        Returns
        -------
        The output values of the last iteration that was run, or an empty list
        if no iteration was run.
        """
        last_outputs: Sequence[Num] = []
        while self._iteration < iteration:
            # Ask every input for its value at the current iteration.
            input_values = [
                input_function(self._iteration)
                for input_function in self._input_functions
            ]
            # Filled in by the SFG with the values computed during this iteration.
            iteration_results = {}
            last_outputs = self._sfg.evaluate_outputs(
                input_values,
                iteration_results,
                self._delays,
                "",
                bits_override,
                truncate,
            )
            if save_results:
                for result_key, result_value in iteration_results.items():
                    self._results[result_key].append(result_value)
            self._iteration += 1
        return last_outputs

    def run_for(
        self,
        iterations: int,
        save_results: bool = True,
        bits_override: Optional[int] = None,
        truncate: bool = True,
    ) -> Sequence[Num]:
        """
        Run a given number of iterations of the simulation.

        The remaining arguments are passed on unchanged to :meth:`run_until`.

        Returns
        -------
        The output values of the last iteration.
        """
        target_iteration = self._iteration + iterations
        return self.run_until(target_iteration, save_results, bits_override, truncate)

    def run(
        self,
        save_results: bool = True,
        bits_override: Optional[int] = None,
        truncate: bool = True,
    ) -> Sequence[Num]:
        """
        Run the simulation until the end of its input arrays.

        The arguments are passed on unchanged to :meth:`run_until`.

        Returns
        -------
        The output values of the last iteration.

        Raises
        ------
        IndexError
            If no input has been given as a sequence, so there is no end to
            run to.
        """
        final_iteration = self._input_length
        if final_iteration is None:
            raise IndexError("Tried to run unlimited simulation")
        return self.run_until(final_iteration, save_results, bits_override, truncate)

    @property
    def iteration(self) -> int:
        """Get the current iteration number of the simulation."""
        return self._iteration

    @property
    def results(self) -> ResultArrayMap:
        """
        Get a mapping from result keys to numpy arrays containing all results.

        This includes intermediate values, calculated for each iteration up until now
        that was run with *save_results* enabled.
        The mapping is indexed using the ``key()`` method of Operation with the
        appropriate output index.
        Example result after 3 iterations::

            {"c1": [3, 6, 7], "c2": [4, 5, 5], "bfly1.0": [7, 0, 0], "bfly1.1": [-1, 0, 2], "0": [7, -2, -1]}
        """
        # A new dict with new arrays is built on every access.
        return {
            result_key: np.array(saved_values)
            for result_key, saved_values in self._results.items()
        }

    def clear_results(self) -> None:
        """Clear all results that were saved until now."""
        self._results.clear()

    def clear_state(self) -> None:
        """
        Clear all current state of the simulation, except for the results and iteration.
        """
        self._delays.clear()

    def show(self) -> None:
        """Show the simulation results."""
        # import here to avoid cyclic imports
        from b_asic.gui_utils.plot_window import start_simulation_dialog

        current_results = self.results
        start_simulation_dialog(current_results, self._sfg.name)