395 lines
19 KiB
Python
395 lines
19 KiB
Python
"""
|
|
Base ViewSet with account filtering support
|
|
Unified API Standard v1.0 compliant
|
|
"""
|
|
import logging

from django.core.exceptions import PermissionDenied
from django.http import Http404
from rest_framework import viewsets, status
from rest_framework.exceptions import APIException
from rest_framework.exceptions import ValidationError as DRFValidationError
from rest_framework.response import Response

from .response import success_response, error_response
|
|
|
|
|
|
class AccountModelViewSet(viewsets.ModelViewSet):
    """
    Base ViewSet that automatically filters by account.

    All module ViewSets should inherit from this.

    When the model has an ``account`` attribute, ``get_queryset`` scopes rows:

    * unauthenticated requests get an empty queryset;
    * users for whom ``is_admin_or_developer()`` or ``is_system_account_user()``
      is truthy see every account;
    * everyone else sees only rows of ``request.account`` (set by middleware),
      falling back to ``request.user.account``. If no account can be resolved
      the queryset is empty, so a missing account never exposes other
      accounts' rows.

    ``retrieve``, ``create``, ``update`` and ``destroy`` return the unified
    envelope produced by ``success_response`` / ``error_response``.
    """

    @staticmethod
    def _is_privileged_user(user):
        """Return True if *user* may bypass account scoping.

        ``is_admin_or_developer`` / ``is_system_account_user`` are optional
        methods; a missing method counts as "not privileged".
        """
        if hasattr(user, 'is_admin_or_developer') and user.is_admin_or_developer():
            return True
        return bool(hasattr(user, 'is_system_account_user') and user.is_system_account_user())

    def _resolve_account(self):
        """Return the account for the current request, or None if unavailable.

        Prefers ``request.account`` (set by middleware) and falls back to the
        authenticated user's ``account`` attribute.
        """
        account = getattr(self.request, 'account', None)
        if account:
            return account
        user = getattr(self.request, 'user', None)
        if not (user and getattr(user, 'is_authenticated', False)):
            return None
        try:
            return getattr(user, 'account', None)
        except Exception:
            # Loading the relation can fail at the database level (e.g. a
            # column mismatch); treat that as "no account" but leave a trace.
            logging.getLogger(__name__).warning(
                "Could not load account for user %r",
                getattr(user, 'pk', None),
                exc_info=True,
            )
            return None

    def get_queryset(self):
        """Return the base queryset scoped to the caller's account (see class docstring)."""
        queryset = super().get_queryset()
        if not hasattr(queryset.model, 'account'):
            return queryset

        user = getattr(self.request, 'user', None)
        if not (user and getattr(user, 'is_authenticated', False)):
            # Require authentication - unauthenticated users see nothing.
            return queryset.none()

        try:
            if self._is_privileged_user(user):
                # Admins, developers and system-account users see all accounts.
                return queryset
        except (AttributeError, TypeError):
            # Privileges could not be determined; fail closed.
            return queryset.none()

        account = self._resolve_account()
        if not account:
            # Fail closed: without an account we cannot tell which rows belong
            # to the caller, and returning everything would leak other
            # accounts' data.
            return queryset.none()
        return queryset.filter(account=account)

    def perform_create(self, serializer):
        """Save *serializer*, stamping the request's account when the model has one."""
        account = self._resolve_account()
        if account and hasattr(serializer.Meta.model, 'account'):
            serializer.save(account=account)
        else:
            serializer.save()

    def get_serializer_context(self):
        """Extend the default context with ``account`` when middleware set one."""
        context = super().get_serializer_context()
        account = getattr(self.request, 'account', None)
        if account:
            context['account'] = account
        return context

    def _exception_response(self, exc, request, action, *, guess_validation=False):
        """Translate *exc* into a unified error response.

        * DRF ``ValidationError`` -> 400 "Validation error" with its ``detail``.
        * Django ``Http404`` -> 404.
        * Django ``PermissionDenied`` -> 403.
        * Any other DRF ``APIException`` -> that exception's ``status_code``.
        * Anything else is logged with its traceback. When *guess_validation*
          is set and the message mentions "required", "invalid" or
          "validation" it becomes a 400; otherwise a 500.

        Args:
            exc: The exception caught by the calling action.
            request: The current request, forwarded to ``error_response``.
            action: Action name used in the log message (e.g. ``'create'``).
            guess_validation: Enable the message-based 400 heuristic.
        """
        if isinstance(exc, DRFValidationError):
            return error_response(
                error='Validation error',
                errors=getattr(exc, 'detail', str(exc)),
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request,
            )
        if isinstance(exc, Http404):
            return error_response(
                error=str(exc),
                status_code=status.HTTP_404_NOT_FOUND,
                request=request,
            )
        if isinstance(exc, PermissionDenied):
            return error_response(
                error=str(exc),
                status_code=status.HTTP_403_FORBIDDEN,
                request=request,
            )
        if isinstance(exc, APIException):
            return error_response(
                error=str(exc),
                status_code=exc.status_code,
                request=request,
            )

        logging.getLogger(__name__).error(
            "Error in %s method: %s", action, exc, exc_info=exc
        )
        message = str(exc)
        lowered = message.lower()
        if guess_validation and any(
            word in lowered for word in ('required', 'invalid', 'validation')
        ):
            return error_response(
                error='Validation error',
                errors=message,
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request,
            )
        return error_response(
            error=f'Internal server error: {message}',
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            request=request,
        )

    def retrieve(self, request, *args, **kwargs):
        """Return a single object in the unified format.

        Failures are mapped by ``_exception_response``; a missing object
        yields a 404.
        """
        try:
            instance = self.get_object()
            serializer = self.get_serializer(instance)
            return success_response(data=serializer.data, request=request)
        except Exception as exc:
            return self._exception_response(exc, request, 'retrieve')

    def create(self, request, *args, **kwargs):
        """Create an object and return it in the unified format with HTTP 201.

        Headers from ``get_success_headers`` (e.g. ``Location``) are attached
        to the response, as DRF's default ``create`` does.
        """
        serializer = self.get_serializer(data=request.data)
        try:
            serializer.is_valid(raise_exception=True)
            self.perform_create(serializer)
            response = success_response(
                data=serializer.data,
                message='Created successfully',
                request=request,
                status_code=status.HTTP_201_CREATED,
            )
            for header, value in self.get_success_headers(serializer.data).items():
                response[header] = value
            return response
        except Exception as exc:
            return self._exception_response(exc, request, 'create', guess_validation=True)

    def update(self, request, *args, **kwargs):
        """Update an object (fully or partially) and return it in the unified format."""
        partial = kwargs.pop('partial', False)
        # Lookup failures (404/403) propagate to DRF's exception handler.
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        try:
            serializer.is_valid(raise_exception=True)
            self.perform_update(serializer)
            if getattr(instance, '_prefetched_objects_cache', None):
                # If prefetch_related was applied, drop the cache so the
                # response does not serialize stale related objects.
                instance._prefetched_objects_cache = {}
            return success_response(
                data=serializer.data,
                message='Updated successfully',
                request=request,
            )
        except Exception as exc:
            return self._exception_response(exc, request, 'update', guess_validation=True)

    def destroy(self, request, *args, **kwargs):
        """Delete an object and return a unified success envelope."""
        try:
            instance = self.get_object()
            self.perform_destroy(instance)
            # NOTE(review): HTTP 204 responses must not carry a body, so some
            # clients/proxies may drop this envelope. Kept for API compatibility.
            return success_response(
                data=None,
                message='Deleted successfully',
                request=request,
                status_code=status.HTTP_204_NO_CONTENT,
            )
        except Exception as exc:
            return self._exception_response(exc, request, 'destroy')
|
|
|
|
|
|
class SiteSectorModelViewSet(AccountModelViewSet):
    """
    Base ViewSet for models that belong to Site and Sector (Keywords, Clusters, etc.).

    For models exposing both ``site`` and ``sector`` attributes, querysets are
    filtered by:

    1. Account (via the parent class).
    2. The user's accessible sites (``user.get_accessible_sites()``), unless the
       user is an admin/developer or a system-account user.
    3. Optional ``site_id`` / ``sector_id`` query parameters.

    Rows whose ``site`` is null are always excluded.
    """

    @staticmethod
    def _has_global_site_access(user):
        """Return True if *user* may read and write every site.

        True when the optional ``is_admin_or_developer`` or
        ``is_system_account_user`` method returns a truthy value.
        """
        if hasattr(user, 'is_admin_or_developer') and user.is_admin_or_developer():
            return True
        return bool(hasattr(user, 'is_system_account_user') and user.is_system_account_user())

    def _query_param(self, name):
        """Return query parameter *name*, or None if absent or unreadable.

        Uses DRF's ``request.query_params`` and falls back to ``request.GET``
        for plain Django requests.
        """
        try:
            params = getattr(self.request, 'query_params', None)
            if params is None:
                params = getattr(self.request, 'GET', {})
            return params.get(name)
        except AttributeError:
            return None

    def _restrict_to_accessible_sites(self, queryset, user):
        """Limit *queryset* to the sites *user* may access.

        Returns an empty queryset for unauthenticated users, users without
        ``get_accessible_sites``, users with no accessible sites, or when the
        user's attributes cannot be read.
        """
        if not (
            user
            and getattr(user, 'is_authenticated', False)
            and hasattr(user, 'get_accessible_sites')
        ):
            return queryset.none()
        try:
            if self._has_global_site_access(user):
                return queryset
            accessible_sites = user.get_accessible_sites()
            if not accessible_sites.exists():
                return queryset.none()
            return queryset.filter(site__in=accessible_sites)
        except (AttributeError, TypeError):
            return queryset.none()

    def get_queryset(self):
        """Return the account-scoped queryset further limited by site access and query params.

        ``site_id`` / ``sector_id`` query params narrow the results; a
        non-integer value yields an empty queryset and a zero value is ignored.
        """
        queryset = super().get_queryset()
        model = queryset.model
        if not (hasattr(model, 'site') and hasattr(model, 'sector')):
            return queryset

        queryset = self._restrict_to_accessible_sites(
            queryset, getattr(self.request, 'user', None)
        )

        # The queryset is already limited to the caller's accessible sites (or
        # empty), so filtering by an id the caller cannot access yields no rows
        # without a separate access query.
        for param in ('site_id', 'sector_id'):
            raw = self._query_param(param)
            if not raw:
                continue
            try:
                value = int(raw)
            except (ValueError, TypeError):
                queryset = queryset.none()
                continue
            if value:
                queryset = queryset.filter(**{param: value})

        # Always exclude orphaned rows (site is null) so keywords, clusters,
        # ideas and tasks not tied to any site are never listed.
        return queryset.exclude(site__isnull=True)

    def _validate_site_sector(self, validated_data):
        """Check that the caller may write to the chosen site and that the sector matches it.

        Only runs for authenticated users when a ``site`` was supplied;
        unauthenticated writes are left to the view's permission classes.

        Raises:
            PermissionDenied: the caller cannot access the site, or their
                permissions could not be determined.
            DRFValidationError: ``sector`` belongs to a different site.
        """
        user = getattr(self.request, 'user', None)
        site = validated_data.get('site')
        sector = validated_data.get('sector')
        if not (user and getattr(user, 'is_authenticated', False) and site):
            return
        try:
            if not self._has_global_site_access(user):
                # Fail closed: without get_accessible_sites() access cannot be
                # shown (get_queryset hides everything from such users too).
                can_access = (
                    hasattr(user, 'get_accessible_sites')
                    and user.get_accessible_sites().filter(id=site.id).exists()
                )
                if not can_access:
                    raise PermissionDenied("You do not have access to this site")
            sector_mismatch = bool(sector) and hasattr(sector, 'site') and sector.site != site
        except (AttributeError, TypeError) as exc:
            raise PermissionDenied("Unable to verify access permissions") from exc
        if sector_mismatch:
            raise DRFValidationError({'sector': ["Sector must belong to the selected site"]})

    def perform_create(self, serializer):
        """Validate site/sector access, then save via the parent (which sets the account).

        Validation runs BEFORE saving so a rejected request never persists a row.

        Raises:
            PermissionDenied: see ``_validate_site_sector``.
            DRFValidationError: see ``_validate_site_sector``.
        """
        model = serializer.Meta.model
        if hasattr(model, 'site') and hasattr(model, 'sector'):
            self._validate_site_sector(serializer.validated_data)
        super().perform_create(serializer)

    def get_serializer_context(self):
        """Add ``accessible_sites`` and active ``accessible_sectors`` to the serializer context.

        Both are empty querysets for unauthenticated users or when the user's
        accessible sites cannot be read.
        """
        from igny8_core.auth.models import Site, Sector

        context = super().get_serializer_context()
        user = getattr(self.request, 'user', None)
        accessible_sites = Site.objects.none()
        accessible_sectors = Sector.objects.none()

        if (
            user
            and getattr(user, 'is_authenticated', False)
            and hasattr(user, 'get_accessible_sites')
        ):
            try:
                accessible_sites = user.get_accessible_sites()
                accessible_sectors = Sector.objects.filter(
                    site__in=accessible_sites,
                    is_active=True,
                )
            except (AttributeError, TypeError):
                accessible_sites = Site.objects.none()
                accessible_sectors = Sector.objects.none()

        context['accessible_sites'] = accessible_sites
        context['accessible_sectors'] = accessible_sectors
        return context
|
|
|
|
|
|
class StandardResponseMixin:
    """
    Mixin providing helpers that wrap payloads in the standard response envelope.

    Success envelopes carry ``success``, ``message`` and ``data``; error
    envelopes carry ``success``, ``message`` and ``errors``.
    """

    def get_response(self, data, message=None, status_code=200):
        """Return a ``Response`` with ``success=True`` wrapping *data*."""
        envelope = dict(success=True, message=message, data=data)
        return Response(envelope, status=status_code)

    def get_error_response(self, message, errors=None, status_code=400):
        """Return a ``Response`` with ``success=False`` carrying *message* and *errors*."""
        envelope = dict(success=False, message=message, errors=errors)
        return Response(envelope, status=status_code)
|
|
|