from __future__ import annotations

__all__ = (
    "AsyncCacheInfo",
    "AsyncCacheParameters",
    "AsyncLRUCacheWrapper",
    "cache",
    "lru_cache",
    "reduce",
)

import functools
import sys
from collections import OrderedDict
from collections.abc import (
    AsyncIterable,
    Awaitable,
    Callable,
    Coroutine,
    Hashable,
    Iterable,
)
from functools import update_wrapper
from inspect import iscoroutinefunction
from typing import (
    Any,
    Generic,
    NamedTuple,
    TypedDict,
    TypeVar,
    cast,
    final,
    overload,
)
from weakref import WeakKeyDictionary

from ._core._synchronization import Lock
from .lowlevel import RunVar, checkpoint

if sys.version_info >= (3, 11):
    from typing import ParamSpec
else:
    from typing_extensions import ParamSpec

T = TypeVar("T")
S = TypeVar("S")
P = ParamSpec("P")
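
# Per-event-loop cache storage: maps each cache wrapper to an ordered mapping of
# call keys, each holding either (initial_missing, Lock) while the value is
# being computed, or (value, None) once it has been cached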
lru_cache_items: RunVar[
    WeakKeyDictionary[
        AsyncLRUCacheWrapper[Any, Any],
        OrderedDict[Hashable, tuple[_InitialMissingType, Lock] | tuple[Any, None]],
    ]
] = RunVar("lru_cache_items")


class _InitialMissingType:
    """Sentinel type marking a missing value: a cache entry that has not been
    computed yet, or no ``initial`` argument given to :func:`reduce`. Also used
    as a separator in cache keys."""


initial_missing: _InitialMissingType = _InitialMissingType()


class AsyncCacheInfo(NamedTuple):
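    """Cache statistics returned by :meth:`AsyncLRUCacheWrapper.cache_info`."""
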
    hits: int
    misses: int
    maxsize: int | None
    currsize: int


class AsyncCacheParameters(TypedDict):
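    """Cache configuration returned by
    :meth:`AsyncLRUCacheWrapper.cache_parameters`."""
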
    maxsize: int | None
    typed: bool
    always_checkpoint: bool


@final
class AsyncLRUCacheWrapper(Generic[P, T]):
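    """
    Wraps a coroutine function with per-event-loop LRU caching.

    Concurrent calls with the same arguments share a lock, so the wrapped
    function runs at most once per key while other callers wait and then
    receive the cached value.
    """
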
    def __init__(
        self,
        func: Callable[..., Awaitable[T]],
        maxsize: int | None,
        typed: bool,
        always_checkpoint: bool,
    ):
        self.__wrapped__ = func
        self._hits: int = 0
        self._misses: int = 0
        self._maxsize = max(maxsize, 0) if maxsize is not None else None
        self._currsize: int = 0
        self._typed = typed
        self._always_checkpoint = always_checkpoint
        update_wrapper(self, func)

    def cache_info(self) -> AsyncCacheInfo:
        return AsyncCacheInfo(self._hits, self._misses, self._maxsize, self._currsize)

    def cache_parameters(self) -> AsyncCacheParameters:
        return {
            "maxsize": self._maxsize,
            "typed": self._typed,
            "always_checkpoint": self._always_checkpoint,
        }

    def cache_clear(self) -> None:
        if cache := lru_cache_items.get(None):
            cache.pop(self, None)

        # Reset the statistics even if no cache exists on the current event loop
        self._hits = self._misses = self._currsize = 0

    async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # Easy case first: if maxsize == 0, no caching is done
        if self._maxsize == 0:
            self._misses += 1
            value = await self.__wrapped__(*args, **kwargs)
            if self._always_checkpoint:
                # Uphold the at-least-one-checkpoint guarantee even when
                # caching is disabled
                await checkpoint()

            return value

        # The key is constructed as a flat tuple to avoid memory overhead
        key: tuple[Any, ...] = args
        if kwargs:
            # initial_missing doubles as a separator between the positional and
            # keyword portions of the key
            key += (initial_missing,) + sum(kwargs.items(), ())

        if self._typed:
            key += tuple(type(arg) for arg in args)
            if kwargs:
                key += (initial_missing,) + tuple(type(val) for val in kwargs.values())
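        # For example, f(1, x=2) with typed=True produces the key
        # (1, initial_missing, "x", 2, int, initial_missing, int)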

        try:
            cache = lru_cache_items.get()
        except LookupError:
            cache = WeakKeyDictionary()
            lru_cache_items.set(cache)

        try:
            cache_entry = cache[self]
        except KeyError:
            cache_entry = cache[self] = OrderedDict()

        cached_value: T | _InitialMissingType
        try:
            cached_value, lock = cache_entry[key]
        except KeyError:
            # We're the first task to call this function with these arguments
            cached_value, lock = (
                initial_missing,
                Lock(fast_acquire=not self._always_checkpoint),
            )
            cache_entry[key] = cached_value, lock

        if lock is None:
            # The value was already cached
            self._hits += 1
            cache_entry.move_to_end(key)
            if self._always_checkpoint:
                await checkpoint()

            return cast(T, cached_value)

        async with lock:
            # Check if another task filled the cache while we waited for the lock
            if (cached_value := cache_entry[key][0]) is initial_missing:
                self._misses += 1
                if self._maxsize is not None and self._currsize >= self._maxsize:
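                    # At capacity: evict the least recently used entry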
                    cache_entry.popitem(last=False)
                else:
                    self._currsize += 1

                value = await self.__wrapped__(*args, **kwargs)
                cache_entry[key] = value, None
            else:
                # Another task filled the cache while we were waiting for the lock
                self._hits += 1
                cache_entry.move_to_end(key)
                value = cast(T, cached_value)

        return value


class _LRUCacheWrapper(Generic[T]):
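    """
    Decorator object returned by :func:`lru_cache`.

    Dispatches coroutine functions to :class:`AsyncLRUCacheWrapper` and
    synchronous functions to :func:`functools.lru_cache`.
    """
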
    def __init__(self, maxsize: int | None, typed: bool, always_checkpoint: bool):
        self._maxsize = maxsize
        self._typed = typed
        self._always_checkpoint = always_checkpoint

    @overload
    def __call__(  # type: ignore[overload-overlap]
        self, func: Callable[P, Coroutine[Any, Any, T]], /
    ) -> AsyncLRUCacheWrapper[P, T]: ...

    @overload
    def __call__(
        self, func: Callable[..., T], /
    ) -> functools._lru_cache_wrapper[T]: ...

    def __call__(
        self, f: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T], /
    ) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
        if iscoroutinefunction(f):
            return AsyncLRUCacheWrapper(
                f, self._maxsize, self._typed, self._always_checkpoint
            )

        return functools.lru_cache(maxsize=self._maxsize, typed=self._typed)(f)  # type: ignore[arg-type]


@overload
def cache(  # type: ignore[overload-overlap]
    func: Callable[P, Coroutine[Any, Any, T]], /
) -> AsyncLRUCacheWrapper[P, T]: ...


@overload
def cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...


def cache(
    func: Callable[..., T] | Callable[P, Coroutine[Any, Any, T]], /
) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
    """
    A convenient shortcut for :func:`lru_cache` with ``maxsize=None``.

    This is the asynchronous equivalent to :func:`functools.cache`.
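
    Usage example (a minimal sketch)::

        @cache
        async def get_remote_settings() -> dict[str, str]:
            ...  # computed on the first call, then cached per event loop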

    """
    return lru_cache(maxsize=None)(func)


@overload
def lru_cache(
    *, maxsize: int | None = ..., typed: bool = ..., always_checkpoint: bool = ...
) -> _LRUCacheWrapper[Any]: ...


@overload
def lru_cache(  # type: ignore[overload-overlap]
    func: Callable[P, Coroutine[Any, Any, T]], /
) -> AsyncLRUCacheWrapper[P, T]: ...


@overload
def lru_cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...


def lru_cache(
    func: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T] | None = None,
    /,
    *,
    maxsize: int | None = 128,
    typed: bool = False,
    always_checkpoint: bool = False,
) -> (
    AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T] | _LRUCacheWrapper[Any]
):
    """
    An asynchronous version of :func:`functools.lru_cache`.

    If a synchronous function is passed, the standard library
    :func:`functools.lru_cache` is applied instead.

    :param maxsize: the maximum number of entries to cache (``None`` disables
        the size limit)
    :param typed: if ``True``, arguments of different types are cached
        separately (e.g. ``1`` and ``1.0`` produce distinct entries)
    :param always_checkpoint: if ``True``, every call to the cached function is
        guaranteed to yield control to the event loop at least once

    .. note:: Caches and locks are managed on a per-event loop basis.
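
    Usage example (a minimal sketch)::

        @lru_cache(maxsize=32, always_checkpoint=True)
        async def resolve(hostname: str) -> str:
            ...  # at most 32 results kept; every call yields to the event loop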

    """
    if func is None:
        return _LRUCacheWrapper[Any](maxsize, typed, always_checkpoint)

    if not callable(func):
        raise TypeError("the first argument must be callable")

    return _LRUCacheWrapper[T](maxsize, typed, always_checkpoint)(func)


@overload
async def reduce(
    function: Callable[[T, S], Awaitable[T]],
    iterable: Iterable[S] | AsyncIterable[S],
    /,
    initial: T,
) -> T: ...


@overload
async def reduce(
    function: Callable[[T, T], Awaitable[T]],
    iterable: Iterable[T] | AsyncIterable[T],
    /,
) -> T: ...


async def reduce(  # type: ignore[misc]
    function: Callable[[T, T], Awaitable[T]] | Callable[[T, S], Awaitable[T]],
    iterable: Iterable[T] | Iterable[S] | AsyncIterable[T] | AsyncIterable[S],
    /,
    initial: T | _InitialMissingType = initial_missing,
) -> T:
    """
    Asynchronous version of :func:`functools.reduce`.

    :param function: an async callable that takes two arguments: the accumulated
        value and the next element from the iterable
    :param iterable: an iterable or async iterable
    :param initial: the initial value (if omitted, the first element of the
        iterable is used instead)
    :raises TypeError: if the iterable is empty and no initial value was given
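
    Usage example (a minimal sketch)::

        async def add(total: int, item: int) -> int:
            return total + item

        total = await reduce(add, [1, 2, 3, 4], 0)  # total == 10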

    """
    element: Any
    function_called = False
    if isinstance(iterable, AsyncIterable):
        async_it = iterable.__aiter__()
        if initial is initial_missing:
            try:
                value = cast(T, await async_it.__anext__())
            except StopAsyncIteration:
                raise TypeError(
                    "reduce() of empty sequence with no initial value"
                ) from None
        else:
            value = cast(T, initial)

        async for element in async_it:
            value = await function(value, element)
            function_called = True
    elif isinstance(iterable, Iterable):
        it = iter(iterable)
        if initial is initial_missing:
            try:
                value = cast(T, next(it))
            except StopIteration:
                raise TypeError(
                    "reduce() of empty sequence with no initial value"
                ) from None
        else:
            value = cast(T, initial)

        for element in it:
            value = await function(value, element)
            function_called = True
    else:
        raise TypeError("reduce() argument 2 must be an iterable or async iterable")

    # Ensure at least one checkpoint even if the function was never awaited
    # (e.g. an empty iterable combined with an initial value)
    if not function_called:
        await checkpoint()

    return value
