"""Executor util helpers."""
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
import contextlib
import logging
import sys
from threading import Thread
import time
import traceback
from typing import Any

from .thread import async_raise

# Module logger; used to warn about threads still running at shutdown.
_LOGGER = logging.getLogger(__name__)

# Number of initial join attempts during which the stacks of threads that are
# still alive get logged; later attempts still interrupt them, but silently.
MAX_LOG_ATTEMPTS = 2

# Each join attempt is given 1/_JOIN_ATTEMPTS of the shutdown time still
# remaining. Because that slice shrinks with the remaining time, this acts as
# a divisor rather than a hard cap on the number of attempts.
_JOIN_ATTEMPTS = 10

# Total time budget, in seconds, for joining executor threads on shutdown.
EXECUTOR_SHUTDOWN_TIMEOUT = 10


def _log_thread_running_at_shutdown(name: str, ident: int) -> None:
    """Log the stack of a thread that was still running at shutdown.

    name: thread name, used only in the log message.
    ident: ``Thread.ident`` of the thread whose current stack is logged.

    Nothing is logged if the thread has no current frame any more (it exited
    after the caller decided it was still alive).
    """
    frames = sys._current_frames()  # pylint: disable=protected-access
    stack = frames.get(ident)
    if stack is None:
        # The thread finished in the meantime. traceback.format_stack(None)
        # would format the *calling* thread's stack instead, producing a
        # misleading warning, so there is nothing meaningful to report.
        return
    formatted_stack = traceback.format_stack(stack)
    _LOGGER.warning(
        "Thread[%s] is still running at shutdown: %s",
        name,
        "".join(formatted_stack).strip(),
    )


def join_or_interrupt_threads(
    threads: set[Thread], timeout: float, log: bool
) -> set[Thread]:
    """Attempt to join or interrupt a set of threads.

    The *timeout* budget is split evenly across *threads*, and each thread is
    joined for its share. Threads that have finished, or were never started,
    are collected and returned. Threads still alive afterwards optionally have
    their stack logged (when *log* is true) and then get ``SystemExit`` raised
    in them via ``async_raise``.

    Returns the subset of *threads* that is no longer running. An empty
    *threads* set yields an empty result.
    """
    joined: set[Thread] = set()
    if not threads:
        # Nothing to wait for; also avoids dividing the timeout by zero.
        return joined

    timeout_per_thread = timeout / len(threads)

    for thread in threads:
        ident = thread.ident
        if ident is None:
            # Never started: Thread.join() would raise RuntimeError, and there
            # is nothing running to wait for or interrupt.
            joined.add(thread)
            continue

        thread.join(timeout=timeout_per_thread)

        if not thread.is_alive():
            joined.add(thread)
            continue

        if log:
            _log_thread_running_at_shutdown(thread.name, ident)

        with contextlib.suppress(SystemError):
            # SystemError at this stage is usually a race condition
            # where the thread happens to die right before we force
            # it to raise the exception
            async_raise(ident, SystemExit)

    return joined


class InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
    """A ThreadPoolExecutor instance that will not deadlock on shutdown.

    Shutting down cancels pending futures. It then joins the worker threads
    for a bounded amount of time, interrupting any that are still alive
    instead of waiting on them forever.
    """

    def shutdown(self, *args: Any, **kwargs: Any) -> None:
        """Shut down, interrupting worker threads that do not exit in time.

        Any ``wait`` / ``cancel_futures`` arguments are accepted for signature
        compatibility but ignored. The parent shutdown is always non-blocking
        with pending futures cancelled, and the bounded wait is done by
        ``join_threads_or_timeout``.
        """
        super().shutdown(wait=False, cancel_futures=True)
        self.join_threads_or_timeout()

    def join_threads_or_timeout(self) -> None:
        """Join worker threads until they are gone or the timeout runs out.

        Each round hands ``join_or_interrupt_threads`` 1/_JOIN_ATTEMPTS of the
        time still left. Stacks of lingering threads are logged only during
        the first MAX_LOG_ATTEMPTS rounds.
        """
        pending = set(self._threads)
        began = time.monotonic()
        budget: float = EXECUTOR_SHUTDOWN_TIMEOUT
        round_no = 0

        while pending:
            round_no += 1
            finished = join_or_interrupt_threads(
                pending,
                budget / _JOIN_ATTEMPTS,
                round_no <= MAX_LOG_ATTEMPTS,
            )
            pending -= finished

            # Recompute what is left of the overall budget after this round.
            budget = EXECUTOR_SHUTDOWN_TIMEOUT - (time.monotonic() - began)
            if budget <= 0:
                break