Enhance API response handling and implement unified API standard across multiple modules. Added feature flags for unified exception handling and debug throttling in settings. Updated pagination and response formats in various viewsets to align with the new standard. Improved error handling and response validation in frontend components for better user feedback.
This commit is contained in:
176
backend/igny8_core/api/exception_handlers.py
Normal file
176
backend/igny8_core/api/exception_handlers.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Centralized Exception Handler
|
||||
Wraps all exceptions in unified format
|
||||
"""
|
||||
import logging
|
||||
from rest_framework.views import exception_handler
|
||||
from rest_framework import status
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import IntegrityError
|
||||
from .response import get_request_id, error_response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def custom_exception_handler(exc, context):
|
||||
"""
|
||||
Custom exception handler that wraps all errors in unified format
|
||||
|
||||
Args:
|
||||
exc: The exception that was raised
|
||||
context: Dictionary containing request, view, args, kwargs
|
||||
|
||||
Returns:
|
||||
Response object with unified error format
|
||||
"""
|
||||
# Get request from context
|
||||
request = context.get('request')
|
||||
|
||||
# Get request ID
|
||||
request_id = get_request_id(request) if request else None
|
||||
|
||||
# Call DRF's default exception handler first
|
||||
response = exception_handler(exc, context)
|
||||
|
||||
# If DRF handled it, wrap it in unified format
|
||||
if response is not None:
|
||||
# Extract error details from DRF response
|
||||
error_message = None
|
||||
errors = None
|
||||
status_code = response.status_code
|
||||
|
||||
# Try to extract error message from response data
|
||||
if hasattr(response, 'data'):
|
||||
if isinstance(response.data, dict):
|
||||
# DRF validation errors
|
||||
if 'detail' in response.data:
|
||||
error_message = str(response.data['detail'])
|
||||
elif 'non_field_errors' in response.data:
|
||||
error_message = str(response.data['non_field_errors'][0]) if response.data['non_field_errors'] else None
|
||||
errors = response.data
|
||||
else:
|
||||
# Field-specific errors
|
||||
errors = response.data
|
||||
# Create top-level error message
|
||||
if errors:
|
||||
first_error = list(errors.values())[0] if errors else None
|
||||
if first_error and isinstance(first_error, list) and len(first_error) > 0:
|
||||
error_message = str(first_error[0])
|
||||
elif first_error:
|
||||
error_message = str(first_error)
|
||||
else:
|
||||
error_message = 'Validation failed'
|
||||
elif isinstance(response.data, list):
|
||||
# List of errors
|
||||
error_message = str(response.data[0]) if response.data else 'Validation failed'
|
||||
else:
|
||||
error_message = str(response.data)
|
||||
|
||||
# Map status codes to appropriate error messages
|
||||
if not error_message:
|
||||
if status_code == status.HTTP_400_BAD_REQUEST:
|
||||
error_message = 'Bad request'
|
||||
elif status_code == status.HTTP_401_UNAUTHORIZED:
|
||||
error_message = 'Authentication required'
|
||||
elif status_code == status.HTTP_403_FORBIDDEN:
|
||||
error_message = 'Permission denied'
|
||||
elif status_code == status.HTTP_404_NOT_FOUND:
|
||||
error_message = 'Resource not found'
|
||||
elif status_code == status.HTTP_409_CONFLICT:
|
||||
error_message = 'Conflict'
|
||||
elif status_code == status.HTTP_422_UNPROCESSABLE_ENTITY:
|
||||
error_message = 'Validation failed'
|
||||
elif status_code == status.HTTP_429_TOO_MANY_REQUESTS:
|
||||
error_message = 'Rate limit exceeded'
|
||||
elif status_code >= 500:
|
||||
error_message = 'Internal server error'
|
||||
else:
|
||||
error_message = 'An error occurred'
|
||||
|
||||
# Prepare debug info (only in DEBUG mode)
|
||||
debug_info = None
|
||||
if settings.DEBUG:
|
||||
debug_info = {
|
||||
'exception_type': type(exc).__name__,
|
||||
'exception_message': str(exc),
|
||||
'view': context.get('view').__class__.__name__ if context.get('view') else None,
|
||||
'path': request.path if request else None,
|
||||
'method': request.method if request else None,
|
||||
}
|
||||
# Include traceback in debug mode
|
||||
import traceback
|
||||
debug_info['traceback'] = traceback.format_exc()
|
||||
|
||||
# Log the error
|
||||
if status_code >= 500:
|
||||
logger.error(
|
||||
f"Server error: {error_message}",
|
||||
extra={
|
||||
'request_id': request_id,
|
||||
'endpoint': request.path if request else None,
|
||||
'method': request.method if request else None,
|
||||
'user_id': request.user.id if request and request.user and request.user.is_authenticated else None,
|
||||
'account_id': request.account.id if request and hasattr(request, 'account') and request.account else None,
|
||||
'status_code': status_code,
|
||||
'exception_type': type(exc).__name__,
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
elif status_code >= 400:
|
||||
logger.warning(
|
||||
f"Client error: {error_message}",
|
||||
extra={
|
||||
'request_id': request_id,
|
||||
'endpoint': request.path if request else None,
|
||||
'method': request.method if request else None,
|
||||
'user_id': request.user.id if request and request.user and request.user.is_authenticated else None,
|
||||
'account_id': request.account.id if request and hasattr(request, 'account') and request.account else None,
|
||||
'status_code': status_code,
|
||||
}
|
||||
)
|
||||
|
||||
# Return unified error response
|
||||
return error_response(
|
||||
error=error_message,
|
||||
errors=errors,
|
||||
status_code=status_code,
|
||||
request=request,
|
||||
debug_info=debug_info
|
||||
)
|
||||
|
||||
# If DRF didn't handle it, it's an unhandled exception
|
||||
# Log it and return unified error response
|
||||
logger.error(
|
||||
f"Unhandled exception: {type(exc).__name__}: {str(exc)}",
|
||||
extra={
|
||||
'request_id': request_id,
|
||||
'endpoint': request.path if request else None,
|
||||
'method': request.method if request else None,
|
||||
'user_id': request.user.id if request and request.user and request.user.is_authenticated else None,
|
||||
'account_id': request.account.id if request and hasattr(request, 'account') and request.account else None,
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# Prepare debug info
|
||||
debug_info = None
|
||||
if settings.DEBUG:
|
||||
import traceback
|
||||
debug_info = {
|
||||
'exception_type': type(exc).__name__,
|
||||
'exception_message': str(exc),
|
||||
'view': context.get('view').__class__.__name__ if context.get('view') else None,
|
||||
'path': request.path if request else None,
|
||||
'method': request.method if request else None,
|
||||
'traceback': traceback.format_exc()
|
||||
}
|
||||
|
||||
return error_response(
|
||||
error='Internal server error',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
request=request,
|
||||
debug_info=debug_info
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""
|
||||
Custom pagination class for DRF to support dynamic page_size query parameter
|
||||
and unified response format
|
||||
"""
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
from .response import get_request_id
|
||||
|
||||
|
||||
class CustomPageNumberPagination(PageNumberPagination):
|
||||
@@ -11,8 +13,37 @@ class CustomPageNumberPagination(PageNumberPagination):
|
||||
|
||||
Default page size: 10
|
||||
Max page size: 100
|
||||
|
||||
Returns unified format with success field
|
||||
"""
|
||||
page_size = 10
|
||||
page_size_query_param = 'page_size'
|
||||
max_page_size = 100
|
||||
|
||||
def paginate_queryset(self, queryset, request, view=None):
|
||||
"""
|
||||
Override to store request for later use in get_paginated_response
|
||||
"""
|
||||
self.request = request
|
||||
return super().paginate_queryset(queryset, request, view)
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
"""
|
||||
Return a paginated response with unified format including success field
|
||||
"""
|
||||
from rest_framework.response import Response
|
||||
|
||||
response_data = {
|
||||
'success': True,
|
||||
'count': self.page.paginator.count,
|
||||
'next': self.get_next_link(),
|
||||
'previous': self.get_previous_link(),
|
||||
'results': data
|
||||
}
|
||||
|
||||
# Add request_id if request is available
|
||||
if hasattr(self, 'request') and self.request:
|
||||
response_data['request_id'] = get_request_id(self.request)
|
||||
|
||||
return Response(response_data)
|
||||
|
||||
|
||||
162
backend/igny8_core/api/permissions.py
Normal file
162
backend/igny8_core/api/permissions.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Standardized Permission Classes
|
||||
Provides consistent permission checking across all endpoints
|
||||
"""
|
||||
from rest_framework import permissions
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
|
||||
|
||||
class IsAuthenticatedAndActive(permissions.BasePermission):
|
||||
"""
|
||||
Permission class that requires user to be authenticated and active
|
||||
Base permission for most endpoints
|
||||
"""
|
||||
def has_permission(self, request, view):
|
||||
if not request.user or not request.user.is_authenticated:
|
||||
return False
|
||||
|
||||
# Check if user is active
|
||||
if hasattr(request.user, 'is_active'):
|
||||
return request.user.is_active
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class HasTenantAccess(permissions.BasePermission):
|
||||
"""
|
||||
Permission class that requires user to belong to the tenant/account
|
||||
Ensures tenant isolation
|
||||
"""
|
||||
def has_permission(self, request, view):
|
||||
if not request.user or not request.user.is_authenticated:
|
||||
return False
|
||||
|
||||
# Get account from request (set by middleware)
|
||||
account = getattr(request, 'account', None)
|
||||
|
||||
# If no account in request, try to get from user
|
||||
if not account and hasattr(request.user, 'account'):
|
||||
try:
|
||||
account = request.user.account
|
||||
except (AttributeError, Exception):
|
||||
pass
|
||||
|
||||
# Admin/Developer/System account users bypass tenant check
|
||||
if request.user and hasattr(request.user, 'is_authenticated') and request.user.is_authenticated:
|
||||
try:
|
||||
is_admin_or_dev = (hasattr(request.user, 'is_admin_or_developer') and
|
||||
request.user.is_admin_or_developer()) if request.user else False
|
||||
is_system_user = (hasattr(request.user, 'is_system_account_user') and
|
||||
request.user.is_system_account_user()) if request.user else False
|
||||
|
||||
if is_admin_or_dev or is_system_user:
|
||||
return True
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
# Regular users must have account access
|
||||
if account:
|
||||
# Check if user belongs to this account
|
||||
if hasattr(request.user, 'account'):
|
||||
try:
|
||||
user_account = request.user.account
|
||||
return user_account == account or user_account.id == account.id
|
||||
except (AttributeError, Exception):
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class IsViewerOrAbove(permissions.BasePermission):
|
||||
"""
|
||||
Permission class that requires viewer, editor, admin, or owner role
|
||||
For read-only operations
|
||||
"""
|
||||
def has_permission(self, request, view):
|
||||
if not request.user or not request.user.is_authenticated:
|
||||
return False
|
||||
|
||||
# Admin/Developer/System account users always have access
|
||||
try:
|
||||
is_admin_or_dev = (hasattr(request.user, 'is_admin_or_developer') and
|
||||
request.user.is_admin_or_developer()) if request.user else False
|
||||
is_system_user = (hasattr(request.user, 'is_system_account_user') and
|
||||
request.user.is_system_account_user()) if request.user else False
|
||||
|
||||
if is_admin_or_dev or is_system_user:
|
||||
return True
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
# Check user role
|
||||
if hasattr(request.user, 'role'):
|
||||
role = request.user.role
|
||||
# viewer, editor, admin, owner all have access
|
||||
return role in ['viewer', 'editor', 'admin', 'owner']
|
||||
|
||||
# If no role system, allow authenticated users
|
||||
return True
|
||||
|
||||
|
||||
class IsEditorOrAbove(permissions.BasePermission):
|
||||
"""
|
||||
Permission class that requires editor, admin, or owner role
|
||||
For content operations
|
||||
"""
|
||||
def has_permission(self, request, view):
|
||||
if not request.user or not request.user.is_authenticated:
|
||||
return False
|
||||
|
||||
# Admin/Developer/System account users always have access
|
||||
try:
|
||||
is_admin_or_dev = (hasattr(request.user, 'is_admin_or_developer') and
|
||||
request.user.is_admin_or_developer()) if request.user else False
|
||||
is_system_user = (hasattr(request.user, 'is_system_account_user') and
|
||||
request.user.is_system_account_user()) if request.user else False
|
||||
|
||||
if is_admin_or_dev or is_system_user:
|
||||
return True
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
# Check user role
|
||||
if hasattr(request.user, 'role'):
|
||||
role = request.user.role
|
||||
# editor, admin, owner have access
|
||||
return role in ['editor', 'admin', 'owner']
|
||||
|
||||
# If no role system, allow authenticated users
|
||||
return True
|
||||
|
||||
|
||||
class IsAdminOrOwner(permissions.BasePermission):
|
||||
"""
|
||||
Permission class that requires admin or owner role only
|
||||
For settings, keys, billing operations
|
||||
"""
|
||||
def has_permission(self, request, view):
|
||||
if not request.user or not request.user.is_authenticated:
|
||||
return False
|
||||
|
||||
# Admin/Developer/System account users always have access
|
||||
try:
|
||||
is_admin_or_dev = (hasattr(request.user, 'is_admin_or_developer') and
|
||||
request.user.is_admin_or_developer()) if request.user else False
|
||||
is_system_user = (hasattr(request.user, 'is_system_account_user') and
|
||||
request.user.is_system_account_user()) if request.user else False
|
||||
|
||||
if is_admin_or_dev or is_system_user:
|
||||
return True
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
# Check user role
|
||||
if hasattr(request.user, 'role'):
|
||||
role = request.user.role
|
||||
# admin, owner have access
|
||||
return role in ['admin', 'owner']
|
||||
|
||||
# If no role system, deny by default for security
|
||||
return False
|
||||
|
||||
|
||||
152
backend/igny8_core/api/response.py
Normal file
152
backend/igny8_core/api/response.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Unified API Response Helpers
|
||||
Provides consistent response format across all endpoints
|
||||
"""
|
||||
from rest_framework.response import Response
|
||||
from rest_framework import status
|
||||
import uuid
|
||||
|
||||
|
||||
def get_request_id(request):
|
||||
"""Get request ID from request object (set by middleware) or headers, or generate new one"""
|
||||
if not request:
|
||||
return None
|
||||
|
||||
# First check if middleware set request_id on request object
|
||||
if hasattr(request, 'request_id') and request.request_id:
|
||||
return request.request_id
|
||||
|
||||
# Fallback to headers
|
||||
if hasattr(request, 'META'):
|
||||
request_id = request.META.get('HTTP_X_REQUEST_ID') or request.META.get('X-Request-ID')
|
||||
if request_id:
|
||||
return request_id
|
||||
|
||||
# Generate new request ID if none found
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def success_response(data=None, message=None, status_code=status.HTTP_200_OK, request=None):
|
||||
"""
|
||||
Create a standardized success response
|
||||
|
||||
Args:
|
||||
data: Response data (dict, list, or None)
|
||||
message: Optional success message
|
||||
status_code: HTTP status code (default: 200)
|
||||
request: Request object (optional, for request_id)
|
||||
|
||||
Returns:
|
||||
Response object with unified format
|
||||
"""
|
||||
response_data = {
|
||||
'success': True,
|
||||
}
|
||||
|
||||
if data is not None:
|
||||
response_data['data'] = data
|
||||
|
||||
if message:
|
||||
response_data['message'] = message
|
||||
|
||||
# Add request_id if request is provided
|
||||
if request:
|
||||
response_data['request_id'] = get_request_id(request)
|
||||
|
||||
return Response(response_data, status=status_code)
|
||||
|
||||
|
||||
def error_response(error=None, errors=None, status_code=status.HTTP_400_BAD_REQUEST, request=None, debug_info=None):
|
||||
"""
|
||||
Create a standardized error response
|
||||
|
||||
Args:
|
||||
error: Top-level error message
|
||||
errors: Field-specific errors (dict of field -> list of errors)
|
||||
status_code: HTTP status code (default: 400)
|
||||
request: Request object (optional, for request_id)
|
||||
debug_info: Debug information (only in DEBUG mode)
|
||||
|
||||
Returns:
|
||||
Response object with unified error format
|
||||
"""
|
||||
response_data = {
|
||||
'success': False,
|
||||
}
|
||||
|
||||
if error:
|
||||
response_data['error'] = error
|
||||
elif status_code == status.HTTP_400_BAD_REQUEST:
|
||||
response_data['error'] = 'Bad request'
|
||||
elif status_code == status.HTTP_401_UNAUTHORIZED:
|
||||
response_data['error'] = 'Authentication required'
|
||||
elif status_code == status.HTTP_403_FORBIDDEN:
|
||||
response_data['error'] = 'Permission denied'
|
||||
elif status_code == status.HTTP_404_NOT_FOUND:
|
||||
response_data['error'] = 'Resource not found'
|
||||
elif status_code == status.HTTP_409_CONFLICT:
|
||||
response_data['error'] = 'Conflict'
|
||||
elif status_code == status.HTTP_422_UNPROCESSABLE_ENTITY:
|
||||
response_data['error'] = 'Validation failed'
|
||||
elif status_code == status.HTTP_429_TOO_MANY_REQUESTS:
|
||||
response_data['error'] = 'Rate limit exceeded'
|
||||
elif status_code >= 500:
|
||||
response_data['error'] = 'Internal server error'
|
||||
else:
|
||||
response_data['error'] = 'An error occurred'
|
||||
|
||||
if errors:
|
||||
response_data['errors'] = errors
|
||||
|
||||
# Add request_id if request is provided
|
||||
if request:
|
||||
response_data['request_id'] = get_request_id(request)
|
||||
|
||||
# Add debug info in DEBUG mode
|
||||
if debug_info:
|
||||
response_data['debug'] = debug_info
|
||||
|
||||
return Response(response_data, status=status_code)
|
||||
|
||||
|
||||
def paginated_response(paginated_data, message=None, request=None):
|
||||
"""
|
||||
Create a standardized paginated response
|
||||
|
||||
Args:
|
||||
paginated_data: Paginated data dict from DRF paginator (contains count, next, previous, results)
|
||||
message: Optional success message
|
||||
request: Request object (optional, for request_id)
|
||||
|
||||
Returns:
|
||||
Response object with unified paginated format
|
||||
"""
|
||||
response_data = {
|
||||
'success': True,
|
||||
}
|
||||
|
||||
# Copy pagination fields from DRF paginator
|
||||
if isinstance(paginated_data, dict):
|
||||
response_data.update({
|
||||
'count': paginated_data.get('count', 0),
|
||||
'next': paginated_data.get('next'),
|
||||
'previous': paginated_data.get('previous'),
|
||||
'results': paginated_data.get('results', [])
|
||||
})
|
||||
else:
|
||||
# Fallback if paginated_data is not a dict
|
||||
response_data['count'] = 0
|
||||
response_data['next'] = None
|
||||
response_data['previous'] = None
|
||||
response_data['results'] = []
|
||||
|
||||
if message:
|
||||
response_data['message'] = message
|
||||
|
||||
# Add request_id if request is provided
|
||||
if request:
|
||||
response_data['request_id'] = get_request_id(request)
|
||||
|
||||
return Response(response_data, status=status.HTTP_200_OK)
|
||||
|
||||
|
||||
136
backend/igny8_core/api/throttles.py
Normal file
136
backend/igny8_core/api/throttles.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Scoped Rate Throttling
|
||||
Provides rate limiting with different scopes for different operation types
|
||||
"""
|
||||
from rest_framework.throttling import ScopedRateThrottle
|
||||
from django.conf import settings
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DebugScopedRateThrottle(ScopedRateThrottle):
|
||||
"""
|
||||
Scoped rate throttle that can be bypassed in debug mode
|
||||
|
||||
Usage:
|
||||
class MyViewSet(viewsets.ModelViewSet):
|
||||
throttle_scope = 'planner'
|
||||
throttle_classes = [DebugScopedRateThrottle]
|
||||
"""
|
||||
|
||||
def allow_request(self, request, view):
|
||||
"""
|
||||
Check if request should be throttled
|
||||
|
||||
Bypasses throttling if:
|
||||
- DEBUG mode is True
|
||||
- IGNY8_DEBUG_THROTTLE environment variable is True
|
||||
- User belongs to aws-admin or other system accounts
|
||||
- User is admin/developer role
|
||||
"""
|
||||
# Check if throttling should be bypassed
|
||||
debug_bypass = getattr(settings, 'DEBUG', False)
|
||||
env_bypass = getattr(settings, 'IGNY8_DEBUG_THROTTLE', False)
|
||||
|
||||
# Bypass for system account users (aws-admin, default-account, etc.)
|
||||
system_account_bypass = False
|
||||
if hasattr(request, 'user') and request.user and hasattr(request.user, 'is_authenticated') and request.user.is_authenticated:
|
||||
try:
|
||||
# Check if user is in system account (aws-admin, default-account, default)
|
||||
if hasattr(request.user, 'is_system_account_user') and request.user.is_system_account_user():
|
||||
system_account_bypass = True
|
||||
# Also bypass for admin/developer roles
|
||||
elif hasattr(request.user, 'is_admin_or_developer') and request.user.is_admin_or_developer():
|
||||
system_account_bypass = True
|
||||
except (AttributeError, Exception):
|
||||
# If checking fails, continue with normal throttling
|
||||
pass
|
||||
|
||||
if debug_bypass or env_bypass or system_account_bypass:
|
||||
# In debug mode or for system accounts, still set throttle headers but don't actually throttle
|
||||
# This allows testing throttle headers without blocking requests
|
||||
if hasattr(self, 'get_rate'):
|
||||
# Set headers for debugging
|
||||
self.scope = getattr(view, 'throttle_scope', None)
|
||||
if self.scope:
|
||||
# Get rate for this scope
|
||||
rate = self.get_rate()
|
||||
if rate:
|
||||
# Parse rate (e.g., "10/min")
|
||||
num_requests, duration = self.parse_rate(rate)
|
||||
# Set headers
|
||||
request._throttle_debug_info = {
|
||||
'scope': self.scope,
|
||||
'rate': rate,
|
||||
'limit': num_requests,
|
||||
'duration': duration
|
||||
}
|
||||
return True
|
||||
|
||||
# Normal throttling behavior
|
||||
return super().allow_request(request, view)
|
||||
|
||||
def get_rate(self):
|
||||
"""
|
||||
Get rate for the current scope
|
||||
"""
|
||||
if not self.scope:
|
||||
return None
|
||||
|
||||
# Get throttle rates from settings
|
||||
throttle_rates = getattr(settings, 'REST_FRAMEWORK', {}).get('DEFAULT_THROTTLE_RATES', {})
|
||||
|
||||
# Get rate for this scope
|
||||
rate = throttle_rates.get(self.scope)
|
||||
|
||||
# Fallback to default if scope not found
|
||||
if not rate:
|
||||
rate = throttle_rates.get('default', '100/min')
|
||||
|
||||
return rate
|
||||
|
||||
def parse_rate(self, rate):
|
||||
"""
|
||||
Parse rate string (e.g., "10/min") into (num_requests, duration)
|
||||
|
||||
Returns:
|
||||
tuple: (num_requests, duration_in_seconds)
|
||||
"""
|
||||
if not rate:
|
||||
return None, None
|
||||
|
||||
try:
|
||||
num, period = rate.split('/')
|
||||
num_requests = int(num)
|
||||
|
||||
# Parse duration
|
||||
period = period.strip().lower()
|
||||
if period == 'sec' or period == 's':
|
||||
duration = 1
|
||||
elif period == 'min' or period == 'm':
|
||||
duration = 60
|
||||
elif period == 'hour' or period == 'h':
|
||||
duration = 3600
|
||||
elif period == 'day' or period == 'd':
|
||||
duration = 86400
|
||||
else:
|
||||
# Default to seconds
|
||||
duration = 1
|
||||
|
||||
return num_requests, duration
|
||||
except (ValueError, AttributeError):
|
||||
# Invalid rate format, default to 100/min
|
||||
logger.warning(f"Invalid rate format: {rate}, defaulting to 100/min")
|
||||
return 100, 60
|
||||
|
||||
def throttle_success(self):
|
||||
"""
|
||||
Called when request is allowed
|
||||
Sets throttle headers on response
|
||||
"""
|
||||
# This is called by DRF after allow_request returns True
|
||||
# Headers are set automatically by ScopedRateThrottle
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user