322 lines
12 KiB
Python
322 lines
12 KiB
Python
"""
|
|
Authentication URL Configuration
|
|
"""
|
|
from django.urls import path, include
|
|
from django.views.decorators.csrf import csrf_exempt
|
|
from rest_framework.routers import DefaultRouter
|
|
from rest_framework.views import APIView
|
|
from rest_framework.response import Response
|
|
from rest_framework import status, permissions
|
|
from drf_spectacular.utils import extend_schema
|
|
from igny8_core.api.response import success_response, error_response
|
|
from .views import (
|
|
GroupsViewSet, UsersViewSet, AccountsViewSet, SubscriptionsViewSet,
|
|
SiteUserAccessViewSet, PlanViewSet, SiteViewSet, SectorViewSet,
|
|
IndustryViewSet, SeedKeywordViewSet
|
|
)
|
|
from .serializers import RegisterSerializer, LoginSerializer, ChangePasswordSerializer, UserSerializer, RefreshTokenSerializer
|
|
from .models import User
|
|
from .utils import generate_access_token, get_token_expiry, decode_token
|
|
import jwt
|
|
|
|
router = DefaultRouter()

# Registration table: (URL prefix, viewset, basename), registered in order.
# The first five entries form the main structure (Groups, Users, Accounts,
# Subscriptions, Site User Access); the rest are supporting viewsets.
for _prefix, _viewset, _basename in (
    # Main structure
    (r'groups', GroupsViewSet, 'group'),
    (r'users', UsersViewSet, 'user'),
    (r'accounts', AccountsViewSet, 'account'),
    (r'subscriptions', SubscriptionsViewSet, 'subscription'),
    (r'site-access', SiteUserAccessViewSet, 'site-access'),
    # Supporting viewsets
    (r'plans', PlanViewSet, 'plan'),
    (r'sites', SiteViewSet, 'site'),
    (r'sectors', SectorViewSet, 'sector'),
    (r'industries', IndustryViewSet, 'industry'),
    (r'seed-keywords', SeedKeywordViewSet, 'seed-keyword'),
):
    router.register(_prefix, _viewset, basename=_basename)
# Drop the loop variables so they do not linger in the module namespace.
del _prefix, _viewset, _basename

# Note: AuthViewSet removed - using direct APIView endpoints instead (login, register, etc.)
|
|
|
|
|
|
@extend_schema(
    tags=['Authentication'],
    summary='User Registration',
    description='Register a new user account',
)
class RegisterView(APIView):
    """Registration endpoint.

    Validates the payload with ``RegisterSerializer`` and creates the user.
    It then calls Django's ``login()`` for the new user and responds with
    HTTP 201. The response carries the serialized user plus an access/refresh
    JWT pair and their ISO-8601 expiry timestamps.

    An invalid payload gets HTTP 400 with the serializer's field errors.
    """

    permission_classes = [permissions.AllowAny]

    def post(self, request):
        from django.contrib.auth import login
        from django.utils import timezone
        from .utils import generate_access_token, generate_refresh_token, get_token_expiry

        serializer = RegisterSerializer(data=request.data)
        if not serializer.is_valid():
            return error_response(
                error='Validation failed',
                errors=serializer.errors,
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request,
            )

        new_user = serializer.save()

        # Create a Django session as well, for session-authenticated clients.
        login(request, new_user)

        # The user may have no account attached; the token helpers receive None then.
        owning_account = getattr(new_user, 'account', None)

        tokens = {
            'access': generate_access_token(new_user, owning_account),
            'refresh': generate_refresh_token(new_user, owning_account),
        }
        tokens['access_expires_at'] = (timezone.now() + get_token_expiry('access')).isoformat()
        tokens['refresh_expires_at'] = (timezone.now() + get_token_expiry('refresh')).isoformat()

        return success_response(
            data={'user': UserSerializer(new_user).data, 'tokens': tokens},
            message='Registration successful',
            status_code=status.HTTP_201_CREATED,
            request=request,
        )
|
|
|
|
|
|
@extend_schema(
    tags=['Authentication'],
    summary='User Login',
    description='Authenticate user and receive JWT tokens',
)
class LoginView(APIView):
    """Login endpoint.

    Authenticates by email and password. On success it calls Django's
    ``login()`` and returns the user payload together with an access/refresh
    JWT pair and their ISO-8601 expiry timestamps. ``remember_me`` (default
    False) is forwarded to the access-token helpers.

    All failures return the same HTTP 401 "Invalid credentials" response:
    unknown email, wrong password, and inactive account (``is_active`` False).
    Unknown emails also run the password hasher, so the response time does
    not reveal which emails are registered.

    An invalid payload gets HTTP 400 with the serializer's field errors.
    """

    permission_classes = [permissions.AllowAny]

    def post(self, request):
        from django.contrib.auth import login
        from django.utils import timezone
        from .utils import generate_access_token, generate_refresh_token, get_access_token_expiry, get_token_expiry

        serializer = LoginSerializer(data=request.data)
        if not serializer.is_valid():
            return error_response(
                error='Validation failed',
                errors=serializer.errors,
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request,
            )

        email = serializer.validated_data['email']
        password = serializer.validated_data['password']
        remember_me = serializer.validated_data.get('remember_me', False)

        try:
            user = User.objects.select_related('account', 'account__plan').get(email=email)
        except User.DoesNotExist:
            # Hash the supplied password anyway, as Django's ModelBackend does.
            # A missing account then costs about as much time as a wrong
            # password, which blunts timing-based email enumeration.
            User().set_password(password)
            return self._invalid_credentials(request)

        # Inactive users are refused, consistent with Django's ModelBackend and
        # DRF's SessionAuthentication. The password is checked first so that
        # timing does not distinguish inactive accounts.
        if not user.check_password(password) or not user.is_active:
            return self._invalid_credentials(request)

        # Log the user in (create session for session authentication)
        login(request, user)

        account = getattr(user, 'account', None)

        # Generate JWT tokens; both expiries share one reference time.
        access_token = generate_access_token(user, account, remember_me=remember_me)
        refresh_token = generate_refresh_token(user, account)
        now = timezone.now()
        access_expires_at = now + get_access_token_expiry(remember_me=remember_me)
        refresh_expires_at = now + get_token_expiry('refresh')

        return success_response(
            data={
                'user': self._serialize_user(user),
                'access': access_token,
                'refresh': refresh_token,
                'access_expires_at': access_expires_at.isoformat(),
                'refresh_expires_at': refresh_expires_at.isoformat(),
            },
            message='Login successful',
            request=request,
        )

    @staticmethod
    def _invalid_credentials(request):
        """Build the generic 401 response shared by every authentication failure."""
        return error_response(
            error='Invalid credentials',
            status_code=status.HTTP_401_UNAUTHORIZED,
            request=request,
        )

    @staticmethod
    def _serialize_user(user):
        """Serialize *user* with ``UserSerializer``, falling back to a minimal dict.

        The fallback is deliberate best-effort: a serializer failure is logged
        but does not fail the login.
        """
        try:
            return UserSerializer(user).data
        except Exception as e:
            # Fallback if serializer fails (e.g., missing account_id column)
            import logging
            logger = logging.getLogger(__name__)
            logger.warning("UserSerializer failed for user %s: %s", user.id, e, exc_info=True)

            # Use the email prefix when the username is empty or the placeholder 'user'.
            username = user.username if user.username and user.username != 'user' else user.email.split('@')[0]

            return {
                'id': user.id,
                'username': username,
                'email': user.email,
                'role': user.role,
                'account': None,
                'accessible_sites': [],
            }
|
|
|
|
|
|
@extend_schema(
    tags=['Authentication'],
    summary='Change Password',
    description='Change user password',
)
class ChangePasswordView(APIView):
    """Change password endpoint.

    Requires an authenticated user. The payload is validated with
    ``ChangePasswordSerializer``, the current password is verified, and the new
    password is stored. The user's session auth hash is then refreshed so the
    current session survives the change.

    Responses:
    - HTTP 400 with field errors for an invalid payload.
    - HTTP 400 "Current password is incorrect" when the old password does not match.
    """

    permission_classes = [permissions.IsAuthenticated]

    def post(self, request):
        from django.contrib.auth import update_session_auth_hash

        serializer = ChangePasswordSerializer(data=request.data, context={'request': request})
        if not serializer.is_valid():
            return error_response(
                error='Validation failed',
                errors=serializer.errors,
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request,
            )

        user = request.user
        if not user.check_password(serializer.validated_data['old_password']):
            return error_response(
                error='Current password is incorrect',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request,
            )

        user.set_password(serializer.validated_data['new_password'])
        user.save()

        # Changing the password changes the session auth hash, so Django would
        # invalidate the session opened by login() on the next request and log
        # the user out. Re-stamp the current session with the new hash.
        update_session_auth_hash(request, user)

        return success_response(
            message='Password changed successfully',
            request=request,
        )
|
|
|
|
|
|
@extend_schema(
    tags=['Authentication'],
    summary='Refresh Token',
    description='Refresh access token using refresh token',
)
class RefreshTokenView(APIView):
    """Refresh access token endpoint.

    Decodes the submitted refresh JWT, checks that its ``type`` claim is
    ``'refresh'``, and loads the user named by ``user_id``. It then issues a
    new access token plus its ISO-8601 expiry timestamp.

    Account selection: the account named by the token's ``account_id`` is used
    if it can be loaded; otherwise the user's own ``account`` attribute is used.
    NOTE(review): no check is made here that the token's account belongs to
    the user.

    Responses:
    - HTTP 400 for an invalid payload or a non-refresh token.
    - HTTP 401 for an invalid or expired token.
    - HTTP 404 when the user no longer exists.
    """

    permission_classes = [permissions.AllowAny]

    def post(self, request):
        from django.core.exceptions import ValidationError
        from django.utils import timezone
        from .models import Account

        serializer = RefreshTokenSerializer(data=request.data)
        if not serializer.is_valid():
            return error_response(
                error='Validation failed',
                errors=serializer.errors,
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request,
            )

        refresh_token = serializer.validated_data['refresh']

        # Decode and validate the refresh token. Only decoding can raise JWT
        # errors, so the try block is kept to that single call.
        try:
            payload = decode_token(refresh_token)
        except jwt.InvalidTokenError:
            return error_response(
                error='Invalid or expired refresh token',
                status_code=status.HTTP_401_UNAUTHORIZED,
                request=request,
            )

        # Verify it's a refresh token
        if payload.get('type') != 'refresh':
            return error_response(
                error='Invalid token type',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request,
            )

        user_id = payload.get('user_id')
        account_id = payload.get('account_id')

        try:
            user = User.objects.select_related('account', 'account__plan').get(id=user_id)
        except User.DoesNotExist:
            return error_response(
                error='User not found',
                status_code=status.HTTP_404_NOT_FOUND,
                request=request,
            )

        # Best-effort lookup of the account carried in the token. A missing row
        # or a malformed id falls back to the user's own account below.
        account = None
        if account_id:
            try:
                account = Account.objects.get(id=account_id)
            except (Account.DoesNotExist, ValueError, TypeError, ValidationError):
                account = None

        if not account:
            account = getattr(user, 'account', None)

        access_token = generate_access_token(user, account)
        # get_token_expiry returns a duration (it is added to timezone.now()
        # elsewhere); it must be turned into an absolute time before
        # .isoformat(), which a duration does not have.
        access_expires_at = timezone.now() + get_token_expiry('access')

        return success_response(
            data={
                'access': access_token,
                'access_expires_at': access_expires_at.isoformat(),
            },
            request=request,
        )
|
|
|
|
|
|
@extend_schema(exclude=True)  # Exclude from public API documentation - internal authenticated endpoint
class MeView(APIView):
    """Return the currently authenticated user.

    The user row is re-read from the database, with ``account`` and
    ``account__plan`` joined, instead of trusting ``request.user``. This way
    account or plan changes are reflected immediately.
    """

    permission_classes = [permissions.IsAuthenticated]

    def get(self, request):
        fresh_user = (
            User.objects
            .select_related('account', 'account__plan')
            .get(id=request.user.id)
        )
        return success_response(
            data={'user': UserSerializer(fresh_user).data},
            request=request,
        )
|
|
|
|
|
|
# Router-generated viewset routes come first, followed by the auth endpoints.
urlpatterns = [path('', include(router.urls))]

# Token-issuing endpoints (no prior authentication required) are wrapped in csrf_exempt.
urlpatterns += [
    path(route, csrf_exempt(view.as_view()), name=route_name)
    for route, view, route_name in (
        ('register/', RegisterView, 'auth-register'),
        ('login/', LoginView, 'auth-login'),
        ('refresh/', RefreshTokenView, 'auth-refresh'),
    )
]

# Endpoints that require an authenticated user keep the default CSRF handling.
urlpatterns += [
    path('change-password/', ChangePasswordView.as_view(), name='auth-change-password'),
    path('me/', MeView.as_view(), name='auth-me'),
]
|
|
|