50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
import logging
|
|
|
|
from fastapi import HTTPException, Request, status
|
|
from redis.exceptions import RedisError
|
|
|
|
# Module-scoped logger; the rate limiter uses it to report skipped checks.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RateLimiter:
    """Fixed-window request rate limiter backed by Redis.

    Each client identifier gets one counter per request path. The counter is
    created with a TTL of ``seconds`` on the first request of a window and is
    incremented on every request. Once it exceeds ``times``, requests are
    rejected with HTTP 429 until the key expires and a new window starts.

    The limiter is best-effort: if ``request.app.state.redis`` is missing, or
    Redis raises ``RedisError``, the request is allowed through.

    Usage as FastAPI dependency:

        @router.post("/action", dependencies=[Depends(RateLimiter(times=5, seconds=60))])
    """

    def __init__(
        self,
        times: int = 10,
        seconds: int = 60,
        *,
        trust_forwarded_for: bool = True,
    ):
        """Configure the limiter.

        Args:
            times: Maximum number of requests allowed per window.
            seconds: Window length in seconds. It is used as a Redis key TTL,
                so it must be a positive integer.
            trust_forwarded_for: When True (the original behavior), the client
                is identified by the first ``X-Forwarded-For`` entry. That
                header is client-controlled unless a trusted reverse proxy
                rewrites it, so pass False when the app is reachable directly;
                the socket peer address is then used instead.

        Raises:
            ValueError: If ``times`` is negative or ``seconds`` is not positive.
        """
        if times < 0:
            raise ValueError("times must be >= 0")
        if seconds <= 0:
            raise ValueError("seconds must be > 0")
        self.times = times
        self.seconds = seconds
        self.trust_forwarded_for = trust_forwarded_for

    async def __call__(self, request: Request) -> None:
        """Count this request and reject it once the window's limit is exceeded.

        Args:
            request: The incoming request; its path and client identifier
                form the Redis counter key.

        Raises:
            HTTPException: 429 Too Many Requests, with a ``Retry-After`` header
                giving the seconds until the current window resets.
        """
        redis = getattr(request.app.state, "redis", None)
        if redis is None:
            # No Redis configured: skip limiting instead of failing requests.
            return

        identifier = self._get_identifier(request, self.trust_forwarded_for)
        key = f"rl:{request.url.path}:{identifier}"

        try:
            # redis-py pipelines are transactional (MULTI/EXEC) by default, so
            # these commands run atomically. SET NX EX creates the counter and
            # its TTL only at the start of a window; INCR keeps the existing
            # TTL. Setting the TTL on every request instead would keep a
            # steadily-active client's counter alive (and blocked) forever.
            pipe = redis.pipeline()
            pipe.set(key, 0, ex=self.seconds, nx=True)
            pipe.incr(key)
            pipe.ttl(key)
            _, current, ttl = await pipe.execute()
        except RedisError:
            logger.warning(
                "Rate limiter skipped because Redis is unavailable",
                exc_info=True,
            )
            return

        if current > self.times:
            # TTL is the time left in this window; fall back to the full window
            # if Redis reports no expiry / missing key (negative TTL).
            retry_after = ttl if ttl > 0 else self.seconds
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="Too many requests, please try again later",
                headers={"Retry-After": str(retry_after)},
            )

    @staticmethod
    def _get_identifier(request: Request, trust_forwarded_for: bool = True) -> str:
        """Return the string used to bucket this request's client.

        When ``trust_forwarded_for`` is True and the first ``X-Forwarded-For``
        entry is non-blank, that entry is used. NOTE: unless a trusted proxy
        rewrites the header, this value is chosen by the client and can be
        rotated to evade the limit.

        Otherwise the socket peer host is used, or ``"unknown"`` when
        ``request.client`` is None.
        """
        if trust_forwarded_for:
            forwarded = request.headers.get("X-Forwarded-For")
            if forwarded:
                first_hop = forwarded.split(",")[0].strip()
                # A blank entry (e.g. ", 1.2.3.4") would otherwise lump
                # unrelated clients into a single "" bucket.
                if first_hop:
                    return first_hop
        return request.client.host if request.client else "unknown"
|