"""Advanced timeout handling.

Set of helper classes to handle timeouts of tasks with advanced options
like zones and freezing of timeouts.
"""
from __future__ import annotations

import asyncio
import enum
from types import TracebackType
from typing import Any

from .async_ import run_callback_threadsafe

# Reserved zone name. TimeoutManager.async_timeout() and async_freeze() use it
# as the default zone and create a global (instead of a zone) context for it.
ZONE_GLOBAL = "global"


class _State(str, enum.Enum):
    """States of a task.

    Life cycle marker stored on the global and zone task contexts.
    """

    # Assigned in the task context constructors, before the context is entered.
    INIT = "INIT"
    # Assigned in __aenter__ once the task is registered.
    ACTIVE = "ACTIVE"
    # Assigned by the timer callback (_on_timeout) when the deadline passed.
    TIMEOUT = "TIMEOUT"
    # Assigned at the end of __aexit__ when no TimeoutError is raised.
    EXIT = "EXIT"


class _GlobalFreezeContext:
    """Context manager that freezes the global timeout.

    While at least one global freeze is registered on the manager, the timers
    of all global tasks and of all zones without a freeze of their own stay
    paused. The timers are started again when the last global freeze exits.

    Usable with ``async with`` (work runs inline) and with a plain ``with``
    (work is scheduled on the captured loop via ``call_soon_threadsafe``).
    """

    def __init__(self, manager: TimeoutManager) -> None:
        """Initialize internal timeout context manager.

        The running event loop is captured, so the object has to be created
        while a loop is running.
        """
        self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
        self._manager: TimeoutManager = manager

    async def __aenter__(self) -> _GlobalFreezeContext:
        """Apply the freeze right away."""
        self._enter()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Release the freeze; exceptions are never suppressed."""
        self._exit()
        return None

    def __enter__(self) -> _GlobalFreezeContext:
        """Schedule the freeze on the event loop (not applied synchronously)."""
        self._loop.call_soon_threadsafe(self._enter)
        return self

    def __exit__(  # pylint: disable=useless-return
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Schedule the release of the freeze on the event loop."""
        self._loop.call_soon_threadsafe(self._exit)
        return None

    def _enter(self) -> None:
        """Run freeze.

        Only the first active global freeze has to pause the timers; a nested
        freeze finds them paused already. Every freeze registers itself
        though, so that its ``_exit`` can unregister it and the timers are
        only restarted once the last freeze is gone.
        """
        if self._manager.freezes_done:
            # Global reset
            for task in self._manager.global_tasks:
                task.pause()

            # Zones reset
            for zone in self._manager.zones.values():
                # A zone with a freeze of its own is paused already.
                if zone.freezes_done:
                    zone.pause()

        self._manager.global_freezes.append(self)

    def _exit(self) -> None:
        """Finish freeze.

        Unregisters this freeze and, if no other global freeze remains,
        restarts the timers of the global tasks and of every zone that is not
        frozen on its own.
        """
        self._manager.global_freezes.remove(self)
        if not self._manager.freezes_done:
            return

        # Global reset
        for task in self._manager.global_tasks:
            task.reset()

        # Zones reset
        for zone in self._manager.zones.values():
            if not zone.freezes_done:
                continue
            zone.reset()


class _ZoneFreezeContext:
    """Freeze the timers of one zone for the lifetime of the context.

    ``async with`` applies the freeze inline; a plain ``with`` hands the work
    to the captured event loop through ``call_soon_threadsafe``.
    """

    def __init__(self, zone: _ZoneTimeoutManager) -> None:
        """Remember the zone and the currently running event loop."""
        self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
        self._zone: _ZoneTimeoutManager = zone

    def _enter(self) -> None:
        """Pause the zone if it is still running, then register the freeze."""
        zone = self._zone
        if zone.freezes_done:
            zone.pause()
        zone.enter_freeze(self)

    def _exit(self) -> None:
        """Unregister the freeze and restart the zone once nothing freezes it."""
        zone = self._zone
        zone.exit_freeze(self)
        if zone.freezes_done:
            zone.reset()

    async def __aenter__(self) -> _ZoneFreezeContext:
        self._enter()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException],
        exc_val: BaseException,
        exc_tb: TracebackType,
    ) -> bool | None:
        self._exit()
        return None

    def __enter__(self) -> _ZoneFreezeContext:
        # Only scheduled here; the loop runs the freeze on its next iteration.
        self._loop.call_soon_threadsafe(self._enter)
        return self

    def __exit__(  # pylint: disable=useless-return
        self,
        exc_type: type[BaseException],
        exc_val: BaseException,
        exc_tb: TracebackType,
    ) -> bool | None:
        self._loop.call_soon_threadsafe(self._exit)
        return None


class _GlobalTaskContext:
    """Context manager that tracks a global task.

    Entering registers the context on the manager and arms a timer for the
    remaining time. When the timer fires the tracked task is cancelled, or,
    if zones are still running, a background waiter cancels it once all zones
    are done and ``cool_down`` seconds have passed. A cancellation caused by
    the timeout is turned into ``asyncio.TimeoutError`` on exit.
    """

    def __init__(
        self,
        manager: TimeoutManager,
        task: asyncio.Task[Any],
        timeout: float,
        cool_down: float,
    ) -> None:
        """Initialize internal timeout context manager.

        manager: owner that lists global tasks, zones and freezes.
        task: task that gets cancelled when the timeout is reached.
        timeout: seconds the timer may run (time spent paused is not counted).
        cool_down: extra seconds to wait after the zones signalled done.
        """
        self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
        self._manager: TimeoutManager = manager
        self._task: asyncio.Task[Any] = task
        self._time_left: float = timeout
        self._expiration_time: float | None = None
        self._timeout_handler: asyncio.Handle | None = None
        self._wait_zone: asyncio.Event = asyncio.Event()
        # Background waiter started by _on_timeout while zones are running.
        self._wait_task: asyncio.Task[None] | None = None
        self._state: _State = _State.INIT
        self._cool_down: float = cool_down

    async def __aenter__(self) -> _GlobalTaskContext:
        """Register the task and start its timer unless a freeze is active."""
        self._manager.global_tasks.append(self)
        # While a global freeze is active the timer must stay off; the freeze
        # calls reset() on every registered global task when it ends.
        if self._manager.freezes_done:
            self._start_timer()
        self._state = _State.ACTIVE
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Unregister the task.

        Raises asyncio.TimeoutError when the block was left through a
        CancelledError after the timeout had been reached.
        """
        self._stop_timer()
        self._manager.global_tasks.remove(self)
        # The waiter must not outlive the context, otherwise it could cancel
        # the task long after this block has been left.
        self._stop_waiting()

        # Timeout on exit
        if exc_type is asyncio.CancelledError and self.state == _State.TIMEOUT:
            raise asyncio.TimeoutError

        self._state = _State.EXIT
        self._wait_zone.set()
        return None

    @property
    def state(self) -> _State:
        """Return state of the Global task."""
        return self._state

    def zones_done_signal(self) -> None:
        """Signal that all zones are done."""
        self._wait_zone.set()

    def _start_timer(self) -> None:
        """Start timeout handler for the remaining time (no-op if running)."""
        if self._timeout_handler:
            return

        self._expiration_time = self._loop.time() + self._time_left
        self._timeout_handler = self._loop.call_at(
            self._expiration_time, self._on_timeout
        )

    def _stop_timer(self) -> None:
        """Stop the timer and remember how much time is left (no-op if idle)."""
        if self._timeout_handler is None:
            return

        self._timeout_handler.cancel()
        self._timeout_handler = None
        # Calculate new timeout
        assert self._expiration_time is not None
        self._time_left = self._expiration_time - self._loop.time()

    def _stop_waiting(self) -> None:
        """Cancel the background waiter if it is still pending."""
        waiter = self._wait_task
        self._wait_task = None
        if waiter is not None and not waiter.done():
            waiter.cancel()

    def _on_timeout(self) -> None:
        """Process timeout."""
        self._state = _State.TIMEOUT
        self._timeout_handler = None

        if self._manager.zones_done:
            self._cancel_task()
            return

        # Zones are still running: wait for them before cancelling. Keep a
        # reference to the waiter so it can neither be garbage collected
        # mid-flight nor be started twice.
        if self._wait_task is None or self._wait_task.done():
            self._wait_task = asyncio.create_task(self._on_wait())

    def _cancel_task(self) -> None:
        """Cancel own task."""
        if self._task.done():
            return
        self._task.cancel()

    def pause(self) -> None:
        """Pause timers while it freeze."""
        self._stop_timer()

    def reset(self) -> None:
        """Reset timer after freeze."""
        self._start_timer()

    async def _on_wait(self) -> None:
        """Wait until zones are done, then cancel the task if still timed out."""
        await self._wait_zone.wait()
        await asyncio.sleep(self._cool_down)  # Allow context switch
        if self.state != _State.TIMEOUT:
            return
        self._cancel_task()


class _ZoneTaskContext:
    """Track one task running inside a zone and enforce its timeout.

    The timer only runs while the zone is not frozen. When it fires, the
    tracked task is cancelled and the resulting CancelledError is reported as
    asyncio.TimeoutError when the context is left.
    """

    def __init__(
        self,
        zone: _ZoneTimeoutManager,
        task: asyncio.Task[Any],
        timeout: float,
    ) -> None:
        """Store zone, task and timeout budget; capture the running loop."""
        self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
        self._zone: _ZoneTimeoutManager = zone
        self._task: asyncio.Task[Any] = task
        self._state: _State = _State.INIT
        self._time_left: float = timeout
        self._expiration_time: float | None = None
        self._timeout_handler: asyncio.Handle | None = None

    async def __aenter__(self) -> _ZoneTaskContext:
        self._zone.enter_task(self)
        self._state = _State.ACTIVE

        # Leave the timer off while the zone is frozen; the zone restarts it
        # through reset() later.
        if self._zone.freezes_done:
            self._start_timer()

        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException],
        exc_val: BaseException,
        exc_tb: TracebackType,
    ) -> bool | None:
        self._zone.exit_task(self)
        self._stop_timer()

        # A cancellation that follows our own timeout surfaces as TimeoutError.
        timed_out = (
            exc_type is asyncio.CancelledError and self.state == _State.TIMEOUT
        )
        if timed_out:
            raise asyncio.TimeoutError

        self._state = _State.EXIT
        return None

    @property
    def state(self) -> _State:
        """Return state of the Zone task."""
        return self._state

    def pause(self) -> None:
        """Halt the timer for the duration of a freeze."""
        self._stop_timer()

    def reset(self) -> None:
        """Resume the timer once the freeze is over."""
        self._start_timer()

    def _start_timer(self) -> None:
        """Schedule the timeout callback for the time that is left."""
        if not self._timeout_handler:
            deadline = self._loop.time() + self._time_left
            self._expiration_time = deadline
            self._timeout_handler = self._loop.call_at(deadline, self._on_timeout)

    def _stop_timer(self) -> None:
        """Cancel the scheduled callback and bank the remaining time."""
        handle = self._timeout_handler
        if handle is None:
            return

        handle.cancel()
        self._timeout_handler = None
        # Remaining budget is whatever was left until the deadline.
        assert self._expiration_time
        self._time_left = self._expiration_time - self._loop.time()

    def _on_timeout(self) -> None:
        """Mark the context as timed out and cancel the task if still running."""
        self._state = _State.TIMEOUT
        self._timeout_handler = None

        if not self._task.done():
            self._task.cancel()


class _ZoneTimeoutManager:
    """Book-keeping for the tasks and freezes of a single named zone."""

    def __init__(self, manager: TimeoutManager, zone: str) -> None:
        """Create an empty zone called ``zone`` that belongs to ``manager``."""
        self._manager: TimeoutManager = manager
        self._zone: str = zone
        self._tasks: list[_ZoneTaskContext] = []
        self._freezes: list[_ZoneFreezeContext] = []

    def __repr__(self) -> str:
        """Show the zone name with its task and freeze counts."""
        task_count = len(self._tasks)
        freeze_count = len(self._freezes)
        return f"<{self.name}: {task_count} / {freeze_count}>"

    @property
    def name(self) -> str:
        """Return Zone name."""
        return self._zone

    @property
    def active(self) -> bool:
        """Return True while the zone holds any task or freeze."""
        return bool(self._tasks or self._freezes)

    @property
    def freezes_done(self) -> bool:
        """Return True if neither this zone nor the manager has a freeze."""
        return not self._freezes and self._manager.freezes_done

    def _drop_if_idle(self) -> None:
        """Ask the manager to drop this zone once nothing is left in it."""
        if not self.active:
            self._manager.drop_zone(self.name)

    def enter_task(self, task: _ZoneTaskContext) -> None:
        """Register a task that started in this zone."""
        self._tasks.append(task)

    def exit_task(self, task: _ZoneTaskContext) -> None:
        """Unregister a finished task and drop the zone if it was the last."""
        self._tasks.remove(task)
        self._drop_if_idle()

    def enter_freeze(self, freeze: _ZoneFreezeContext) -> None:
        """Register a freeze that started in this zone."""
        self._freezes.append(freeze)

    def exit_freeze(self, freeze: _ZoneFreezeContext) -> None:
        """Unregister a finished freeze and drop the zone if it was the last."""
        self._freezes.remove(freeze)
        self._drop_if_idle()

    def pause(self) -> None:
        """Forward a pause to every task of the zone."""
        if self.active:
            for task in self._tasks:
                task.pause()

    def reset(self) -> None:
        """Forward a reset to every task of the zone."""
        if self.active:
            for task in self._tasks:
                task.reset()


class TimeoutManager:
    """Coordinate timeouts across zones.

    Owns the zone registry plus the lists of global tasks and global freezes,
    and hands out the context managers that operate on them.
    """

    def __init__(self) -> None:
        """Capture the running loop and start with empty registries."""
        self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
        self._zones: dict[str, _ZoneTimeoutManager] = {}
        self._globals: list[_GlobalTaskContext] = []
        self._freezes: list[_GlobalFreezeContext] = []

    @property
    def zones_done(self) -> bool:
        """Return True when no zone is registered."""
        return not self._zones

    @property
    def freezes_done(self) -> bool:
        """Return True when no global freeze is registered."""
        return len(self._freezes) == 0

    @property
    def zones(self) -> dict[str, _ZoneTimeoutManager]:
        """Return the zone registry, keyed by zone name."""
        return self._zones

    @property
    def global_tasks(self) -> list[_GlobalTaskContext]:
        """Return the registered global task contexts."""
        return self._globals

    @property
    def global_freezes(self) -> list[_GlobalFreezeContext]:
        """Return the registered global freeze contexts."""
        return self._freezes

    def drop_zone(self, zone_name: str) -> None:
        """Forget a zone; unknown names are ignored."""
        self._zones.pop(zone_name, None)

        # Once the last zone is gone, tell every global task about it.
        if not self._zones:
            for global_task in self._globals:
                global_task.zones_done_signal()

    def _zone_for(self, zone_name: str) -> _ZoneTimeoutManager:
        """Return the registered zone of that name, creating it if missing."""
        registry = self.zones
        try:
            return registry[zone_name]
        except KeyError:
            created = _ZoneTimeoutManager(self, zone_name)
            registry[zone_name] = created
            return created

    def async_timeout(
        self, timeout: float, zone_name: str = ZONE_GLOBAL, cool_down: float = 0
    ) -> _ZoneTaskContext | _GlobalTaskContext:
        """Build a timeout context for the current task.

        Intended for ``async with``. The global zone yields a global task
        context, any other name a task context bound to that zone.
        """
        current_task: asyncio.Task[Any] | None = asyncio.current_task()
        assert current_task

        if zone_name == ZONE_GLOBAL:
            return _GlobalTaskContext(self, current_task, timeout, cool_down)

        return _ZoneTaskContext(self._zone_for(zone_name), current_task, timeout)

    def async_freeze(
        self, zone_name: str = ZONE_GLOBAL
    ) -> _ZoneFreezeContext | _GlobalFreezeContext:
        """Build a freeze context for the global zone or a named zone.

        Intended for ``async with``.
        """
        if zone_name == ZONE_GLOBAL:
            return _GlobalFreezeContext(self)

        return _ZoneFreezeContext(self._zone_for(zone_name))

    def freeze(
        self, zone_name: str = ZONE_GLOBAL
    ) -> _ZoneFreezeContext | _GlobalFreezeContext:
        """Build a freeze context from outside the event loop.

        Intended for a plain ``with``. The context is created on the
        manager's loop and this call blocks until it is available.
        """
        pending = run_callback_threadsafe(self._loop, self.async_freeze, zone_name)
        return pending.result()