61 lines
1.8 KiB
Python
61 lines
1.8 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Awaitable, Callable, Generic, TypeVar
|
|
|
|
|
|
# Type variable for the value type held by a cache entry and produced by a loader.
T = TypeVar("T")

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass()
class CacheEntry(Generic[T]):
    """One cached value together with its expiry and last-store timestamps."""

    # The cached payload.
    value: T
    # Instant at or after which the entry counts as expired; TTLCacheService
    # sets it as an aware UTC datetime and compares it to datetime.now(timezone.utc).
    expires_at: datetime
    # Instant the value was stored (aware UTC when written by TTLCacheService).
    updated_at: datetime
|
|
|
|
|
|
class TTLCacheService:
    """In-memory async cache whose entries expire after a per-call TTL.

    Misses for the same key are serialized through a per-key ``asyncio.Lock``,
    so at most one loader call per key is in flight at a time. If a refresh
    fails and an (expired) entry exists, the stale value is served instead of
    raising. Nothing in this class removes entries, so expired ones remain
    available as stale fallbacks.
    """

    def __init__(self) -> None:
        # key -> most recently stored entry (may be expired).
        self._entries: dict[str, CacheEntry[object]] = {}
        # key -> lock serializing loader calls for that key; created lazily.
        self._locks: dict[str, asyncio.Lock] = {}

    async def get_or_load(
        self,
        key: str,
        ttl_seconds: int,
        loader: Callable[[], Awaitable[T]],
    ) -> T:
        """Return the cached value for *key*, loading it if missing or expired.

        Args:
            key: Cache key.
            ttl_seconds: Lifetime of a freshly loaded value, in seconds.
            loader: Zero-argument callable returning an awaitable that yields
                the value.

        Returns:
            The unexpired cached value, a newly loaded value, or -- when
            *loader* raises and an expired entry exists -- that stale value.

        Raises:
            Exception: Whatever *loader* raised, when there is no entry to
                fall back on. Errors while building the new entry (e.g. an
                out-of-range or non-numeric ``ttl_seconds``) always propagate.
        """
        entry = self._entries.get(key)
        if entry is not None and not self._is_expired(entry):
            return entry.value  # type: ignore[return-value]

        lock = self._locks.get(key)
        if lock is None:
            # Allocate only when the key has no lock yet; setdefault() would
            # construct (and usually discard) a new Lock on every miss.
            lock = self._locks[key] = asyncio.Lock()

        async with lock:
            # Re-check: another task may have refreshed the entry while this
            # one was waiting for the lock.
            entry = self._entries.get(key)
            if entry is not None and not self._is_expired(entry):
                return entry.value  # type: ignore[return-value]

            # Guard only the loader call. Errors from building the entry below
            # must not be misreported as loader failures, nor cause a freshly
            # loaded value to be thrown away in favour of stale data.
            try:
                value = await loader()
            except Exception as exc:
                if entry is None:
                    raise
                logger.warning(
                    "Cache loader failed for %s, serving stale data: %s", key, exc
                )
                return entry.value  # type: ignore[return-value]

            now = datetime.now(timezone.utc)
            self._entries[key] = CacheEntry(
                value=value,
                expires_at=now + timedelta(seconds=ttl_seconds),
                updated_at=now,
            )
            return value

    @staticmethod
    def _is_expired(entry: CacheEntry[object]) -> bool:
        """Return True once the current UTC time reaches ``entry.expires_at``."""
        return datetime.now(timezone.utc) >= entry.expires_at
|