from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, cast

from b_asic.core_operations import DontCare
from b_asic.port import OutputPort
from b_asic.special_operations import Delay, Input, Output
from b_asic.types import TypeName

if TYPE_CHECKING:
    from b_asic.operation import Operation
    from b_asic.schedule import Schedule
    from b_asic.types import GraphID


class Scheduler(ABC):
    """Abstract base class for algorithms that fill in a Schedule's start times."""

    @abstractmethod
    def apply_scheduling(self, schedule: "Schedule") -> None:
        """Applies the scheduling algorithm on the given Schedule.

        Parameters
        ----------
        schedule : Schedule
            Schedule to apply the scheduling algorithm on.
        """
        raise NotImplementedError

    def _handle_outputs(
        self,
        schedule: "Schedule",
        non_schedulable_ops: Optional[list["GraphID"]] = None,
    ) -> None:
        """Set the start time of every Output operation in ``schedule``.

        An Output is placed at the start time of the operation driving its
        (first) input signal plus the latency offset of the driving output
        port. Outputs driven by an operation listed in ``non_schedulable_ops``
        are placed at time 0 instead.

        Parameters
        ----------
        schedule : Schedule
            Schedule whose ``start_times`` are updated in place.
        non_schedulable_ops : list of GraphID, optional
            Graph IDs of operations that are not scheduled. ``None`` (the
            default) means there are none.

        Raises
        ------
        ValueError
            If a driving output port has no latency offset.
        """
        # A fresh list per call: a shared mutable default would leak state
        # between calls, and None must not reach the ``in`` test below.
        if non_schedulable_ops is None:
            non_schedulable_ops = []

        for output in schedule.sfg.find_by_type_name(Output.type_name()):
            output = cast(Output, output)
            source_port = cast(OutputPort, output.inputs[0].signals[0].source)
            if source_port.operation.graph_id in non_schedulable_ops:
                schedule.start_times[output.graph_id] = 0
            else:
                if source_port.latency_offset is None:
                    raise ValueError(
                        f"Output port {source_port.index} of operation"
                        f" {source_port.operation.graph_id} has no"
                        " latency-offset."
                    )
                schedule.start_times[output.graph_id] = schedule.start_times[
                    source_port.operation.graph_id
                ] + cast(int, source_port.latency_offset)


class ListScheduler(Scheduler, ABC):
    """Greedy list scheduler with optional per-type resource limits.

    Subclasses define the priority order via :meth:`_get_sorted_operations`
    (earlier entries have higher priority). At each time step, the
    highest-priority remaining operation that passes
    :meth:`_candidate_is_schedulable` is scheduled. When no operation
    qualifies, time advances by one step.

    Parameters
    ----------
    max_resources : dict of TypeName to int, optional
        Maximum number of operations of each type name that may be busy at
        the same time. Type names that are not present are unconstrained.
    """

    def __init__(self, max_resources: Optional[dict[TypeName, int]] = None) -> None:
        if max_resources is None:
            max_resources = {}
        if not isinstance(max_resources, dict):
            raise ValueError("max_resources must be a dictionary.")
        for key, value in max_resources.items():
            if not isinstance(key, str):
                raise ValueError("max_resources key must be a valid type_name.")
            if not isinstance(value, int):
                raise ValueError("max_resources value must be an integer.")

        # Keep a private copy so later mutation of the caller's dict has no
        # effect on this scheduler.
        self._max_resources = dict(max_resources)

    def apply_scheduling(self, schedule: "Schedule") -> None:
        """Applies the scheduling algorithm on the given Schedule.

        The steps are:

        1. Inputs are placed at time 0.
        2. The operations from :meth:`_get_sorted_operations` are list
           scheduled.
        3. The schedule time is set to the start time of the last scheduled
           operation.
        4. Outputs are placed and delays removed.
        5. Input and DontCare operations are moved ALAP.

        Parameters
        ----------
        schedule : Schedule
            Schedule to apply the scheduling algorithm on.
        """
        sfg = schedule.sfg
        start_times = schedule.start_times

        remaining_resources = self._max_resources.copy()
        # Units of constrained resources currently in use, as
        # (type name, time at which the unit becomes free again).
        busy_resources: list[tuple[TypeName, int]] = []
        sorted_operations = self._get_sorted_operations(schedule)

        # place all inputs at time 0
        for input_op in sfg.find_by_type_name(Input.type_name()):
            start_times[input_op.graph_id] = 0

        current_time = 0
        while sorted_operations:
            # Pick the highest-priority operation that can start now. If none
            # can, advance time (freeing finished resources) and search again
            # from the top of the priority list.
            # NOTE(review): this loops forever if some operation can never
            # become schedulable, e.g. a max_resources value of 0 for a type
            # that occurs in the graph.
            candidate = self._find_schedulable_candidate(
                schedule, current_time, remaining_resources, sorted_operations
            )
            while candidate is None:
                current_time += 1
                busy_resources = self._release_resources(
                    busy_resources, current_time, remaining_resources
                )
                candidate = self._find_schedulable_candidate(
                    schedule, current_time, remaining_resources, sorted_operations
                )

            # if the resource is constrained, occupy one unit of it until the
            # operation's execution time (or, if unset/zero, latency) elapses
            type_name = candidate.type_name()
            if type_name in remaining_resources:
                remaining_resources[type_name] -= 1
                duration = candidate.execution_time or candidate.latency
                busy_resources.append((type_name, current_time + duration))

            # schedule the chosen candidate at the current time
            sorted_operations.remove(candidate.graph_id)
            start_times[candidate.graph_id] = current_time

        schedule.set_schedule_time(current_time)

        self._handle_outputs(schedule)
        schedule.remove_delays()

        # move all inputs ALAP now that operations have moved
        for input_op in schedule.sfg.find_by_type_name(Input.type_name()):
            input_op = cast(Input, input_op)
            schedule.move_operation_alap(input_op.graph_id)

        # move all dont cares ALAP
        for dc_op in schedule.sfg.find_by_type_name(DontCare.type_name()):
            dc_op = cast(DontCare, dc_op)
            schedule.move_operation_alap(dc_op.graph_id)

    def _find_schedulable_candidate(
        self,
        schedule: "Schedule",
        current_time: int,
        remaining_resources: dict[TypeName, int],
        sorted_operations: list["GraphID"],
    ) -> Optional["Operation"]:
        """Return the first operation in priority order that can start at
        ``current_time``, or ``None`` if there is none."""
        for graph_id in sorted_operations:
            operation = schedule.sfg.find_by_id(graph_id)
            if self._candidate_is_schedulable(
                schedule.start_times,
                operation,
                current_time,
                remaining_resources,
                sorted_operations,
            ):
                return operation
        return None

    @staticmethod
    def _release_resources(
        busy_resources: list[tuple[TypeName, int]],
        current_time: int,
        remaining_resources: dict[TypeName, int],
    ) -> list[tuple[TypeName, int]]:
        """Return resource units to the pool once they are free.

        Each unit whose ready time is at or before ``current_time`` is
        released, incrementing ``remaining_resources`` in place. Using ``<=``
        also frees units whose operation took no time. Only the units that
        are still busy are returned.
        """
        still_busy = []
        for type_name, ready_time in busy_resources:
            if ready_time <= current_time:
                remaining_resources[type_name] += 1
            else:
                still_busy.append((type_name, ready_time))
        return still_busy

    @staticmethod
    def _candidate_is_schedulable(
        start_times: dict["GraphID", int],
        operation: "Operation",
        current_time: int,
        remaining_resources: dict[TypeName, int],
        remaining_ops: list["GraphID"],
    ) -> bool:
        """Return whether ``operation`` may start at ``current_time``.

        The operation is rejected in either of these cases:

        - Its type has a resource limit with no free units.
        - Any operation driving one of its inputs is still unscheduled.
        - ``current_time`` is earlier than the start time plus latency of any
          non-Delay driving operation.
        """
        if (
            operation.type_name() in remaining_resources
            and remaining_resources[operation.type_name()] == 0
        ):
            return False

        earliest_start_time = 0
        for op_input in operation.inputs:
            source_op = op_input.signals[0].source.operation
            source_op_graph_id = source_op.graph_id

            if source_op_graph_id in remaining_ops:
                return False

            preceding_op_start_time = start_times.get(source_op_graph_id)

            # Delay outputs are available immediately, so they impose no
            # latency constraint on this operation.
            if not isinstance(source_op, Delay):
                earliest_start_time = max(
                    earliest_start_time, preceding_op_start_time + source_op.latency
                )

        return earliest_start_time <= current_time

    @abstractmethod
    def _get_sorted_operations(self, schedule: "Schedule") -> list["GraphID"]:
        """Return the graph IDs of the operations to list schedule.

        Earlier entries have higher priority. The returned list is consumed
        (entries are removed) by :meth:`apply_scheduling`, so implementations
        should return a fresh list.
        """
        raise NotImplementedError