asdasd
This commit is contained in:
401
tenant-temp/backend/igny8_core/api/base.py
Normal file
401
tenant-temp/backend/igny8_core/api/base.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""
|
||||
Base ViewSet with account filtering support
|
||||
Unified API Standard v1.0 compliant
|
||||
"""
|
||||
from rest_framework import viewsets, status
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.exceptions import ValidationError as DRFValidationError
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from .response import success_response, error_response
|
||||
|
||||
|
||||
class AccountModelViewSet(viewsets.ModelViewSet):
|
||||
"""
|
||||
Base ViewSet that automatically filters by account.
|
||||
All module ViewSets should inherit from this.
|
||||
"""
|
||||
def get_queryset(self):
|
||||
queryset = super().get_queryset()
|
||||
# Filter by account if model has account field
|
||||
if hasattr(queryset.model, 'account'):
|
||||
user = getattr(self.request, 'user', None)
|
||||
|
||||
if user and hasattr(user, 'is_authenticated') and user.is_authenticated:
|
||||
try:
|
||||
account = getattr(self.request, 'account', None)
|
||||
if not account and hasattr(self.request, 'user') and self.request.user and hasattr(self.request.user, 'is_authenticated') and self.request.user.is_authenticated:
|
||||
user_account = getattr(self.request.user, 'account', None)
|
||||
if user_account:
|
||||
account = user_account
|
||||
|
||||
if account:
|
||||
queryset = queryset.filter(account=account)
|
||||
else:
|
||||
# No account context -> block access
|
||||
return queryset.none()
|
||||
except (AttributeError, TypeError):
|
||||
# If there's an error accessing user attributes, return empty queryset
|
||||
return queryset.none()
|
||||
else:
|
||||
# Require authentication - return empty queryset for unauthenticated users
|
||||
return queryset.none()
|
||||
return queryset
|
||||
|
||||
def perform_create(self, serializer):
|
||||
# Set account from request (set by middleware)
|
||||
account = getattr(self.request, 'account', None)
|
||||
if not account and hasattr(self.request, 'user') and self.request.user and self.request.user.is_authenticated:
|
||||
try:
|
||||
account = getattr(self.request.user, 'account', None)
|
||||
except (AttributeError, Exception):
|
||||
account = None
|
||||
|
||||
if hasattr(serializer.Meta.model, 'account'):
|
||||
if not account:
|
||||
raise PermissionDenied("Account context is required to create this object.")
|
||||
serializer.save(account=account)
|
||||
else:
|
||||
serializer.save()
|
||||
|
||||
def get_serializer_context(self):
|
||||
context = super().get_serializer_context()
|
||||
# Add account to context for serializers
|
||||
account = getattr(self.request, 'account', None)
|
||||
if account:
|
||||
context['account'] = account
|
||||
return context
|
||||
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
"""
|
||||
Override retrieve to return unified format
|
||||
"""
|
||||
try:
|
||||
instance = self.get_object()
|
||||
serializer = self.get_serializer(instance)
|
||||
return success_response(data=serializer.data, request=request)
|
||||
except Exception as e:
|
||||
return error_response(
|
||||
error=str(e),
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
request=request
|
||||
)
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
"""
|
||||
Override create to return unified format
|
||||
"""
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
try:
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_create(serializer)
|
||||
headers = self.get_success_headers(serializer.data)
|
||||
return success_response(
|
||||
data=serializer.data,
|
||||
message='Created successfully',
|
||||
request=request,
|
||||
status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
except DRFValidationError as e:
|
||||
return error_response(
|
||||
error='Validation error',
|
||||
errors=e.detail if hasattr(e, 'detail') else str(e),
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
request=request
|
||||
)
|
||||
except Exception as e:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Error in create method: {str(e)}", exc_info=True)
|
||||
# Check if it's a validation-related error
|
||||
if 'required' in str(e).lower() or 'invalid' in str(e).lower() or 'validation' in str(e).lower():
|
||||
return error_response(
|
||||
error='Validation error',
|
||||
errors=str(e),
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
request=request
|
||||
)
|
||||
# For other errors, return 500
|
||||
return error_response(
|
||||
error=f'Internal server error: {str(e)}',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
request=request
|
||||
)
|
||||
|
||||
def update(self, request, *args, **kwargs):
|
||||
"""
|
||||
Override update to return unified format
|
||||
"""
|
||||
partial = kwargs.pop('partial', False)
|
||||
instance = self.get_object()
|
||||
serializer = self.get_serializer(instance, data=request.data, partial=partial)
|
||||
try:
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_update(serializer)
|
||||
return success_response(
|
||||
data=serializer.data,
|
||||
message='Updated successfully',
|
||||
request=request
|
||||
)
|
||||
except DRFValidationError as e:
|
||||
return error_response(
|
||||
error='Validation error',
|
||||
errors=e.detail if hasattr(e, 'detail') else str(e),
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
request=request
|
||||
)
|
||||
except Exception as e:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Error in create method: {str(e)}", exc_info=True)
|
||||
# Check if it's a validation-related error
|
||||
if 'required' in str(e).lower() or 'invalid' in str(e).lower() or 'validation' in str(e).lower():
|
||||
return error_response(
|
||||
error='Validation error',
|
||||
errors=str(e),
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
request=request
|
||||
)
|
||||
# For other errors, return 500
|
||||
return error_response(
|
||||
error=f'Internal server error: {str(e)}',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
request=request
|
||||
)
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
"""
|
||||
Override destroy to return unified format
|
||||
"""
|
||||
try:
|
||||
instance = self.get_object()
|
||||
# Protect system account
|
||||
if hasattr(instance, 'slug') and getattr(instance, 'slug', '') == 'aws-admin':
|
||||
from django.core.exceptions import PermissionDenied
|
||||
raise PermissionDenied("System account cannot be deleted.")
|
||||
|
||||
if hasattr(instance, 'soft_delete'):
|
||||
user = getattr(request, 'user', None)
|
||||
retention_days = None
|
||||
account = getattr(instance, 'account', None)
|
||||
if account and hasattr(account, 'deletion_retention_days'):
|
||||
retention_days = account.deletion_retention_days
|
||||
elif hasattr(instance, 'deletion_retention_days'):
|
||||
retention_days = getattr(instance, 'deletion_retention_days', None)
|
||||
instance.soft_delete(
|
||||
user=user if getattr(user, 'is_authenticated', False) else None,
|
||||
retention_days=retention_days,
|
||||
reason='api_delete'
|
||||
)
|
||||
else:
|
||||
self.perform_destroy(instance)
|
||||
return success_response(
|
||||
data=None,
|
||||
message='Deleted successfully',
|
||||
request=request,
|
||||
status_code=status.HTTP_204_NO_CONTENT
|
||||
)
|
||||
except Exception as e:
|
||||
return error_response(
|
||||
error=str(e),
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
request=request
|
||||
)
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
"""
|
||||
Override list to return unified format
|
||||
"""
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
# Check if pagination is enabled
|
||||
page = self.paginate_queryset(queryset)
|
||||
if page is not None:
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
# Use paginator's get_paginated_response which already returns unified format
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
# No pagination - return all results in unified format
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
return success_response(
|
||||
data=serializer.data,
|
||||
request=request
|
||||
)
|
||||
|
||||
|
||||
class SiteSectorModelViewSet(AccountModelViewSet):
|
||||
"""
|
||||
Base ViewSet for models that belong to Site and Sector (Keywords, Clusters, etc.).
|
||||
Automatically filters by:
|
||||
1. Account (via parent class)
|
||||
2. User's accessible sites (based on role and SiteUserAccess)
|
||||
3. Optional site/sector query parameters
|
||||
"""
|
||||
def get_queryset(self):
|
||||
queryset = super().get_queryset()
|
||||
|
||||
# Check if model has site and sector fields (SiteSectorBaseModel)
|
||||
if hasattr(queryset.model, 'site') and hasattr(queryset.model, 'sector'):
|
||||
user = getattr(self.request, 'user', None)
|
||||
|
||||
# Check if user is authenticated and is a proper User instance (not AnonymousUser)
|
||||
if user and hasattr(user, 'is_authenticated') and user.is_authenticated and hasattr(user, 'get_accessible_sites'):
|
||||
try:
|
||||
# Get user's accessible sites
|
||||
accessible_sites = user.get_accessible_sites()
|
||||
|
||||
# If no accessible sites, return empty queryset
|
||||
if not accessible_sites.exists():
|
||||
queryset = queryset.none()
|
||||
else:
|
||||
# Filter by accessible sites
|
||||
queryset = queryset.filter(site__in=accessible_sites)
|
||||
except (AttributeError, TypeError):
|
||||
# If there's an error accessing user attributes, return empty queryset
|
||||
queryset = queryset.none()
|
||||
else:
|
||||
# Require authentication - return empty queryset for unauthenticated users
|
||||
queryset = queryset.none()
|
||||
|
||||
# Optional: Filter by specific site (from query params)
|
||||
# Safely access query_params (DRF wraps request with Request class)
|
||||
try:
|
||||
query_params = getattr(self.request, 'query_params', None)
|
||||
if query_params is None:
|
||||
# Fallback for non-DRF requests
|
||||
query_params = getattr(self.request, 'GET', {})
|
||||
site_id = query_params.get('site_id') or query_params.get('site')
|
||||
else:
|
||||
site_id = query_params.get('site_id') or query_params.get('site')
|
||||
except AttributeError:
|
||||
site_id = None
|
||||
|
||||
if site_id:
|
||||
try:
|
||||
# Convert site_id to int if it's a string
|
||||
site_id_int = int(site_id) if site_id else None
|
||||
if site_id_int:
|
||||
if user and hasattr(user, 'is_authenticated') and user.is_authenticated and hasattr(user, 'get_accessible_sites'):
|
||||
try:
|
||||
accessible_sites = user.get_accessible_sites()
|
||||
if accessible_sites.filter(id=site_id_int).exists():
|
||||
queryset = queryset.filter(site_id=site_id_int)
|
||||
else:
|
||||
queryset = queryset.none() # Site not accessible
|
||||
except (AttributeError, TypeError):
|
||||
# If there's an error accessing user attributes, return empty queryset
|
||||
queryset = queryset.none()
|
||||
else:
|
||||
# Require authentication for site filtering
|
||||
queryset = queryset.none()
|
||||
except (ValueError, TypeError):
|
||||
# Invalid site_id, return empty queryset
|
||||
queryset = queryset.none()
|
||||
|
||||
# Optional: Filter by specific sector (from query params)
|
||||
# Safely access query_params (DRF wraps request with Request class)
|
||||
try:
|
||||
query_params = getattr(self.request, 'query_params', None)
|
||||
if query_params is None:
|
||||
# Fallback for non-DRF requests
|
||||
query_params = getattr(self.request, 'GET', {})
|
||||
sector_id = query_params.get('sector_id')
|
||||
else:
|
||||
sector_id = query_params.get('sector_id')
|
||||
except AttributeError:
|
||||
sector_id = None
|
||||
|
||||
if sector_id:
|
||||
try:
|
||||
# Convert sector_id to int if it's a string
|
||||
sector_id_int = int(sector_id) if sector_id else None
|
||||
if sector_id_int:
|
||||
queryset = queryset.filter(sector_id=sector_id_int)
|
||||
# If site_id also provided, ensure sector belongs to that site
|
||||
if site_id:
|
||||
try:
|
||||
site_id_int = int(site_id) if site_id else None
|
||||
if site_id_int:
|
||||
queryset = queryset.filter(site_id=site_id_int)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
except (ValueError, TypeError):
|
||||
# Invalid sector_id, return empty queryset
|
||||
queryset = queryset.none()
|
||||
|
||||
# Always exclude records where site is null (orphaned records)
|
||||
# This prevents showing keywords/clusters/ideas/tasks that aren't associated with any site
|
||||
# Only skip this if explicitly requested (e.g., for admin cleanup operations)
|
||||
queryset = queryset.exclude(site__isnull=True)
|
||||
|
||||
return queryset
|
||||
|
||||
def perform_create(self, serializer):
|
||||
# First call parent to set account
|
||||
super().perform_create(serializer)
|
||||
|
||||
# If model has site and sector fields, validate access
|
||||
if hasattr(serializer.Meta.model, 'site') and hasattr(serializer.Meta.model, 'sector'):
|
||||
user = getattr(self.request, 'user', None)
|
||||
site = serializer.validated_data.get('site')
|
||||
sector = serializer.validated_data.get('sector')
|
||||
|
||||
if user and hasattr(user, 'is_authenticated') and user.is_authenticated and site:
|
||||
try:
|
||||
if hasattr(user, 'get_accessible_sites'):
|
||||
accessible_sites = user.get_accessible_sites()
|
||||
if not accessible_sites.filter(id=site.id).exists():
|
||||
raise PermissionDenied("You do not have access to this site")
|
||||
|
||||
# Verify sector belongs to site
|
||||
if sector and hasattr(sector, 'site') and sector.site != site:
|
||||
raise PermissionDenied("Sector must belong to the selected site")
|
||||
except (AttributeError, TypeError) as e:
|
||||
# If there's an error accessing user attributes, raise permission denied
|
||||
raise PermissionDenied("Unable to verify access permissions")
|
||||
|
||||
def get_serializer_context(self):
|
||||
context = super().get_serializer_context()
|
||||
user = getattr(self.request, 'user', None)
|
||||
|
||||
# Add accessible sites to context for serializer (e.g., for dropdown choices)
|
||||
if user and hasattr(user, 'is_authenticated') and user.is_authenticated and hasattr(user, 'get_accessible_sites'):
|
||||
try:
|
||||
context['accessible_sites'] = user.get_accessible_sites()
|
||||
# Get accessible sectors from accessible sites
|
||||
from igny8_core.auth.models import Sector
|
||||
context['accessible_sectors'] = Sector.objects.filter(
|
||||
site__in=context['accessible_sites'],
|
||||
is_active=True
|
||||
)
|
||||
except (AttributeError, TypeError):
|
||||
# If there's an error, set empty querysets
|
||||
from igny8_core.auth.models import Site, Sector
|
||||
context['accessible_sites'] = Site.objects.none()
|
||||
context['accessible_sectors'] = Sector.objects.none()
|
||||
else:
|
||||
# Set empty querysets for unauthenticated users
|
||||
from igny8_core.auth.models import Site, Sector
|
||||
context['accessible_sites'] = Site.objects.none()
|
||||
context['accessible_sectors'] = Sector.objects.none()
|
||||
|
||||
return context
|
||||
|
||||
|
||||
class StandardResponseMixin:
|
||||
"""
|
||||
Mixin for standard API response format.
|
||||
"""
|
||||
def get_response(self, data, message=None, status_code=200):
|
||||
return Response({
|
||||
'success': True,
|
||||
'message': message,
|
||||
'data': data
|
||||
}, status=status_code)
|
||||
|
||||
def get_error_response(self, message, errors=None, status_code=400):
|
||||
return Response({
|
||||
'success': False,
|
||||
'message': message,
|
||||
'errors': errors
|
||||
}, status=status_code)
|
||||
|
||||
131
tenant-temp/backend/igny8_core/api/permissions.py
Normal file
131
tenant-temp/backend/igny8_core/api/permissions.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Standardized Permission Classes
|
||||
Provides consistent permission checking across all endpoints
|
||||
"""
|
||||
from rest_framework import permissions
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
|
||||
|
||||
class IsAuthenticatedAndActive(permissions.BasePermission):
|
||||
"""
|
||||
Permission class that requires user to be authenticated and active
|
||||
Base permission for most endpoints
|
||||
"""
|
||||
def has_permission(self, request, view):
|
||||
if not request.user or not request.user.is_authenticated:
|
||||
return False
|
||||
|
||||
# Check if user is active
|
||||
if hasattr(request.user, 'is_active'):
|
||||
return request.user.is_active
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class HasTenantAccess(permissions.BasePermission):
|
||||
"""
|
||||
Permission class that requires user to belong to the tenant/account
|
||||
Ensures tenant isolation
|
||||
"""
|
||||
def has_permission(self, request, view):
|
||||
if not request.user or not request.user.is_authenticated:
|
||||
return False
|
||||
|
||||
# Get account from request (set by middleware)
|
||||
account = getattr(request, 'account', None)
|
||||
|
||||
# If no account in request, try to get from user
|
||||
if not account and hasattr(request.user, 'account'):
|
||||
try:
|
||||
account = request.user.account
|
||||
except (AttributeError, Exception):
|
||||
pass
|
||||
|
||||
# Regular users must have account access
|
||||
if account:
|
||||
# Check if user belongs to this account
|
||||
if hasattr(request.user, 'account'):
|
||||
try:
|
||||
user_account = request.user.account
|
||||
return user_account == account or user_account.id == account.id
|
||||
except (AttributeError, Exception):
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class IsViewerOrAbove(permissions.BasePermission):
|
||||
"""
|
||||
Permission class that requires viewer, editor, admin, or owner role
|
||||
For read-only operations
|
||||
"""
|
||||
def has_permission(self, request, view):
|
||||
if not request.user or not request.user.is_authenticated:
|
||||
return False
|
||||
|
||||
# Check user role
|
||||
if hasattr(request.user, 'role'):
|
||||
role = request.user.role
|
||||
# viewer, editor, admin, owner all have access
|
||||
return role in ['viewer', 'editor', 'admin', 'owner']
|
||||
|
||||
# If no role system, allow authenticated users
|
||||
return True
|
||||
|
||||
|
||||
class IsEditorOrAbove(permissions.BasePermission):
|
||||
"""
|
||||
Permission class that requires editor, admin, or owner role
|
||||
For content operations
|
||||
"""
|
||||
def has_permission(self, request, view):
|
||||
if not request.user or not request.user.is_authenticated:
|
||||
return False
|
||||
|
||||
# Check user role
|
||||
if hasattr(request.user, 'role'):
|
||||
role = request.user.role
|
||||
# editor, admin, owner have access
|
||||
return role in ['editor', 'admin', 'owner']
|
||||
|
||||
# If no role system, allow authenticated users
|
||||
return True
|
||||
|
||||
|
||||
class IsAdminOrOwner(permissions.BasePermission):
|
||||
"""
|
||||
Permission class that requires admin or owner role only
|
||||
For settings, keys, billing operations
|
||||
"""
|
||||
def has_permission(self, request, view):
|
||||
if not request.user or not request.user.is_authenticated:
|
||||
return False
|
||||
|
||||
# Check user role
|
||||
if hasattr(request.user, 'role'):
|
||||
role = request.user.role
|
||||
# admin, owner have access
|
||||
return role in ['admin', 'owner']
|
||||
|
||||
# If no role system, deny by default for security
|
||||
return False
|
||||
|
||||
|
||||
class IsSystemAccountOrDeveloper(permissions.BasePermission):
|
||||
"""
|
||||
Allow only system accounts (aws-admin/default-account/default) or developer role.
|
||||
Use for sensitive, globally-scoped settings like integration API keys.
|
||||
"""
|
||||
def has_permission(self, request, view):
|
||||
user = getattr(request, "user", None)
|
||||
if not user or not user.is_authenticated:
|
||||
return False
|
||||
|
||||
account_slug = getattr(getattr(user, "account", None), "slug", None)
|
||||
if user.role == "developer":
|
||||
return True
|
||||
if account_slug in ["aws-admin", "default-account", "default"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
135
tenant-temp/backend/igny8_core/api/throttles.py
Normal file
135
tenant-temp/backend/igny8_core/api/throttles.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
Scoped Rate Throttling
|
||||
Provides rate limiting with different scopes for different operation types
|
||||
"""
|
||||
from rest_framework.throttling import ScopedRateThrottle
|
||||
from django.conf import settings
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DebugScopedRateThrottle(ScopedRateThrottle):
|
||||
"""
|
||||
Scoped rate throttle that can be bypassed in debug mode
|
||||
|
||||
Usage:
|
||||
class MyViewSet(viewsets.ModelViewSet):
|
||||
throttle_scope = 'planner'
|
||||
throttle_classes = [DebugScopedRateThrottle]
|
||||
"""
|
||||
|
||||
def allow_request(self, request, view):
|
||||
"""
|
||||
Check if request should be throttled
|
||||
|
||||
Bypasses throttling if:
|
||||
- DEBUG mode is True
|
||||
- IGNY8_DEBUG_THROTTLE environment variable is True
|
||||
- User belongs to aws-admin or other system accounts
|
||||
- User is admin/developer role
|
||||
- Public blueprint list request with site filter (for Sites Renderer)
|
||||
"""
|
||||
# Check if throttling should be bypassed
|
||||
debug_bypass = getattr(settings, 'DEBUG', False)
|
||||
env_bypass = getattr(settings, 'IGNY8_DEBUG_THROTTLE', False)
|
||||
|
||||
# Bypass for public blueprint list requests (Sites Renderer fallback)
|
||||
public_blueprint_bypass = False
|
||||
if hasattr(view, 'action') and view.action == 'list':
|
||||
if hasattr(request, 'query_params') and request.query_params.get('site'):
|
||||
if not request.user or not hasattr(request.user, 'is_authenticated') or not request.user.is_authenticated:
|
||||
public_blueprint_bypass = True
|
||||
|
||||
# Bypass for authenticated users (avoid user-facing 429s)
|
||||
authenticated_bypass = False
|
||||
if hasattr(request, 'user') and request.user and hasattr(request.user, 'is_authenticated') and request.user.is_authenticated:
|
||||
authenticated_bypass = True # Do not throttle logged-in users
|
||||
|
||||
if debug_bypass or env_bypass or public_blueprint_bypass or authenticated_bypass:
|
||||
# In debug mode or for system accounts, still set throttle headers but don't actually throttle
|
||||
# This allows testing throttle headers without blocking requests
|
||||
if hasattr(self, 'get_rate'):
|
||||
# Set headers for debugging
|
||||
self.scope = getattr(view, 'throttle_scope', None)
|
||||
if self.scope:
|
||||
# Get rate for this scope
|
||||
rate = self.get_rate()
|
||||
if rate:
|
||||
# Parse rate (e.g., "10/min")
|
||||
num_requests, duration = self.parse_rate(rate)
|
||||
# Set headers
|
||||
request._throttle_debug_info = {
|
||||
'scope': self.scope,
|
||||
'rate': rate,
|
||||
'limit': num_requests,
|
||||
'duration': duration
|
||||
}
|
||||
return True
|
||||
|
||||
# Normal throttling behavior
|
||||
return super().allow_request(request, view)
|
||||
|
||||
def get_rate(self):
|
||||
"""
|
||||
Get rate for the current scope
|
||||
"""
|
||||
if not self.scope:
|
||||
return None
|
||||
|
||||
# Get throttle rates from settings
|
||||
throttle_rates = getattr(settings, 'REST_FRAMEWORK', {}).get('DEFAULT_THROTTLE_RATES', {})
|
||||
|
||||
# Get rate for this scope
|
||||
rate = throttle_rates.get(self.scope)
|
||||
|
||||
# Fallback to default if scope not found
|
||||
if not rate:
|
||||
rate = throttle_rates.get('default', '100/min')
|
||||
|
||||
return rate
|
||||
|
||||
def parse_rate(self, rate):
|
||||
"""
|
||||
Parse rate string (e.g., "10/min") into (num_requests, duration)
|
||||
|
||||
Returns:
|
||||
tuple: (num_requests, duration_in_seconds)
|
||||
"""
|
||||
if not rate:
|
||||
return None, None
|
||||
|
||||
try:
|
||||
num, period = rate.split('/')
|
||||
num_requests = int(num)
|
||||
|
||||
# Parse duration
|
||||
period = period.strip().lower()
|
||||
if period == 'sec' or period == 's':
|
||||
duration = 1
|
||||
elif period == 'min' or period == 'm':
|
||||
duration = 60
|
||||
elif period == 'hour' or period == 'h':
|
||||
duration = 3600
|
||||
elif period == 'day' or period == 'd':
|
||||
duration = 86400
|
||||
else:
|
||||
# Default to seconds
|
||||
duration = 1
|
||||
|
||||
return num_requests, duration
|
||||
except (ValueError, AttributeError):
|
||||
# Invalid rate format, default to 100/min
|
||||
logger.warning(f"Invalid rate format: {rate}, defaulting to 100/min")
|
||||
return 100, 60
|
||||
|
||||
def throttle_success(self):
|
||||
"""
|
||||
Called when request is allowed
|
||||
Sets throttle headers on response
|
||||
"""
|
||||
# This is called by DRF after allow_request returns True
|
||||
# Headers are set automatically by ScopedRateThrottle
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user