"""Rate-limited, retrying request queue for LiteLLM completions."""

import asyncio
import logging
import os
import threading
import time
from typing import Any, NoReturn

import litellm
from litellm import ModelResponse, completion
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential

logger = logging.getLogger(__name__)


def should_retry_exception(exception: Exception) -> bool:
    """Decide whether an exception from a completion call is worth retrying."""
    status_code = None

    # LiteLLM exceptions carry the HTTP status either directly or on the
    # underlying response object.
    if hasattr(exception, "status_code"):
        status_code = exception.status_code
    elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
        status_code = exception.response.status_code

    if status_code is not None:
        return bool(litellm._should_retry(status_code))
    # No status code available (e.g. network errors): retry by default.
    return True


class LLMRequestQueue:
    """Serializes LLM requests behind a concurrency cap and a minimum delay."""

    def __init__(self, max_concurrent: int = 1, delay_between_requests: float = 4.0):
        # Environment variables override the constructor defaults.
        rate_limit_delay = os.getenv("LLM_RATE_LIMIT_DELAY")
        if rate_limit_delay:
            delay_between_requests = float(rate_limit_delay)

        rate_limit_concurrent = os.getenv("LLM_RATE_LIMIT_CONCURRENT")
        if rate_limit_concurrent:
            max_concurrent = int(rate_limit_concurrent)

        self.max_concurrent = max_concurrent
        self.delay_between_requests = delay_between_requests
        self._semaphore = threading.BoundedSemaphore(max_concurrent)
        self._last_request_time = 0.0
        self._lock = threading.Lock()

    async def make_request(self, completion_args: dict[str, Any]) -> ModelResponse:
        # Acquire the semaphore before entering the try block, so that a
        # cancellation during the polling loop cannot reach the finally
        # clause and release a slot that was never acquired (BoundedSemaphore
        # would raise ValueError on such an over-release).
        while not self._semaphore.acquire(timeout=0.2):
            await asyncio.sleep(0.1)
        try:
            # Reserve the next allowed send time under the lock, then sleep
            # outside it so other tasks can line up behind this one.
            with self._lock:
                now = time.time()
                time_since_last = now - self._last_request_time
                sleep_needed = max(0.0, self.delay_between_requests - time_since_last)
                self._last_request_time = now + sleep_needed

            if sleep_needed > 0:
                await asyncio.sleep(sleep_needed)

            return await self._reliable_request(completion_args)
        finally:
            self._semaphore.release()

    @retry(  # type: ignore[misc]
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=8, min=8, max=64),
        retry=retry_if_exception(should_retry_exception),
        reraise=True,
    )
    async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse:
        # Note: litellm.completion is a blocking call, so it occupies the
        # event loop for the duration of the request.
        response = completion(**completion_args, stream=False)
        if isinstance(response, ModelResponse):
            return response
        self._raise_unexpected_response()

    def _raise_unexpected_response(self) -> NoReturn:
        raise RuntimeError("Unexpected response type")


_global_queue: LLMRequestQueue | None = None


def get_global_queue() -> LLMRequestQueue:
    global _global_queue  # noqa: PLW0603
    if _global_queue is None:
        _global_queue = LLMRequestQueue()
    return _global_queue
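

# Example usage (illustrative sketch only: the model name and message payload
# below are placeholders, not values this module prescribes).
if __name__ == "__main__":

    async def _demo() -> None:
        queue = get_global_queue()
        response = await queue.make_request(
            {
                "model": "gpt-4o-mini",  # hypothetical model identifier
                "messages": [{"role": "user", "content": "Say hello."}],
            }
        )
        print(response.choices[0].message.content)

    asyncio.run(_demo())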