Add refresh token functionality and improve login response handling
- Introduced RefreshTokenView to allow users to refresh their access tokens using a valid refresh token. - Enhanced LoginView to ensure correct user/account loading and improved error handling during user serialization. - Updated API response structure to include access and refresh token expiration times. - Adjusted frontend API handling to support both new and legacy token response formats.
This commit is contained in:
@@ -14,8 +14,10 @@ from .views import (
|
||||
SiteUserAccessViewSet, PlanViewSet, SiteViewSet, SectorViewSet,
|
||||
IndustryViewSet, SeedKeywordViewSet
|
||||
)
|
||||
from .serializers import RegisterSerializer, LoginSerializer, ChangePasswordSerializer, UserSerializer
|
||||
from .serializers import RegisterSerializer, LoginSerializer, ChangePasswordSerializer, UserSerializer, RefreshTokenSerializer
|
||||
from .models import User
|
||||
from .utils import generate_access_token, get_token_expiry, decode_token
|
||||
import jwt
|
||||
|
||||
router = DefaultRouter()
|
||||
# Main structure: Groups, Users, Accounts, Subscriptions, Site User Access
|
||||
@@ -78,7 +80,7 @@ class LoginView(APIView):
|
||||
password = serializer.validated_data['password']
|
||||
|
||||
try:
|
||||
user = User.objects.get(email=email)
|
||||
user = User.objects.select_related('account', 'account__plan').get(email=email)
|
||||
except User.DoesNotExist:
|
||||
return error_response(
|
||||
error='Invalid credentials',
|
||||
@@ -107,9 +109,17 @@ class LoginView(APIView):
|
||||
user_data = user_serializer.data
|
||||
except Exception as e:
|
||||
# Fallback if serializer fails (e.g., missing account_id column)
|
||||
# Log the error for debugging but don't fail the login
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"UserSerializer failed for user {user.id}: {e}", exc_info=True)
|
||||
|
||||
# Ensure username is properly set (use email prefix if username is empty/default)
|
||||
username = user.username if user.username and user.username != 'user' else user.email.split('@')[0]
|
||||
|
||||
user_data = {
|
||||
'id': user.id,
|
||||
'username': user.username,
|
||||
'username': username,
|
||||
'email': user.email,
|
||||
'role': user.role,
|
||||
'account': None,
|
||||
@@ -119,12 +129,10 @@ class LoginView(APIView):
|
||||
return success_response(
|
||||
data={
|
||||
'user': user_data,
|
||||
'tokens': {
|
||||
'access': access_token,
|
||||
'refresh': refresh_token,
|
||||
'access_expires_at': access_expires_at.isoformat(),
|
||||
'refresh_expires_at': refresh_expires_at.isoformat(),
|
||||
}
|
||||
},
|
||||
message='Login successful',
|
||||
request=request
|
||||
@@ -180,6 +188,84 @@ class ChangePasswordView(APIView):
|
||||
)
|
||||
|
||||
|
||||
@extend_schema(
|
||||
tags=['Authentication'],
|
||||
summary='Refresh Token',
|
||||
description='Refresh access token using refresh token'
|
||||
)
|
||||
class RefreshTokenView(APIView):
|
||||
"""Refresh access token endpoint."""
|
||||
permission_classes = [permissions.AllowAny]
|
||||
|
||||
def post(self, request):
|
||||
serializer = RefreshTokenSerializer(data=request.data)
|
||||
if not serializer.is_valid():
|
||||
return error_response(
|
||||
error='Validation failed',
|
||||
errors=serializer.errors,
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
request=request
|
||||
)
|
||||
|
||||
refresh_token = serializer.validated_data['refresh']
|
||||
|
||||
try:
|
||||
# Decode and validate refresh token
|
||||
payload = decode_token(refresh_token)
|
||||
|
||||
# Verify it's a refresh token
|
||||
if payload.get('type') != 'refresh':
|
||||
return error_response(
|
||||
error='Invalid token type',
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
request=request
|
||||
)
|
||||
|
||||
# Get user
|
||||
user_id = payload.get('user_id')
|
||||
account_id = payload.get('account_id')
|
||||
|
||||
try:
|
||||
user = User.objects.select_related('account', 'account__plan').get(id=user_id)
|
||||
except User.DoesNotExist:
|
||||
return error_response(
|
||||
error='User not found',
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
request=request
|
||||
)
|
||||
|
||||
# Get account
|
||||
account = None
|
||||
if account_id:
|
||||
try:
|
||||
from .models import Account
|
||||
account = Account.objects.get(id=account_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not account:
|
||||
account = getattr(user, 'account', None)
|
||||
|
||||
# Generate new access token
|
||||
access_token = generate_access_token(user, account)
|
||||
access_expires_at = get_token_expiry('access')
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
'access': access_token,
|
||||
'access_expires_at': access_expires_at.isoformat()
|
||||
},
|
||||
request=request
|
||||
)
|
||||
|
||||
except jwt.InvalidTokenError:
|
||||
return error_response(
|
||||
error='Invalid or expired refresh token',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
request=request
|
||||
)
|
||||
|
||||
|
||||
@extend_schema(exclude=True) # Exclude from public API documentation - internal authenticated endpoint
|
||||
class MeView(APIView):
|
||||
"""Get current user information."""
|
||||
@@ -201,6 +287,7 @@ urlpatterns = [
|
||||
path('', include(router.urls)),
|
||||
path('register/', csrf_exempt(RegisterView.as_view()), name='auth-register'),
|
||||
path('login/', csrf_exempt(LoginView.as_view()), name='auth-login'),
|
||||
path('refresh/', csrf_exempt(RefreshTokenView.as_view()), name='auth-refresh'),
|
||||
path('change-password/', ChangePasswordView.as_view(), name='auth-change-password'),
|
||||
path('me/', MeView.as_view(), name='auth-me'),
|
||||
]
|
||||
|
||||
@@ -933,12 +933,10 @@ class AuthViewSet(viewsets.GenericViewSet):
|
||||
return success_response(
|
||||
data={
|
||||
'user': user_serializer.data,
|
||||
'tokens': {
|
||||
'access': access_token,
|
||||
'refresh': refresh_token,
|
||||
'access_expires_at': access_expires_at.isoformat(),
|
||||
'refresh_expires_at': refresh_expires_at.isoformat(),
|
||||
}
|
||||
},
|
||||
message='Login successful',
|
||||
request=request
|
||||
|
||||
@@ -54,8 +54,8 @@ class CreditBalanceViewSet(viewsets.ViewSet):
|
||||
request=request
|
||||
)
|
||||
|
||||
# Get plan credits per month
|
||||
plan_credits_per_month = account.plan.credits_per_month if account.plan else 0
|
||||
# Get plan credits per month (use get_effective_credits_per_month for Phase 0 compatibility)
|
||||
plan_credits_per_month = account.plan.get_effective_credits_per_month() if account.plan else 0
|
||||
|
||||
# Calculate credits used this month
|
||||
now = timezone.now()
|
||||
|
||||
@@ -235,6 +235,15 @@ class ModuleSettingsViewSet(AccountModelViewSet):
|
||||
|
||||
def retrieve(self, request, pk=None):
|
||||
"""Get setting by key (pk can be key string)"""
|
||||
# Special case: if pk is "enable", this is likely a routing conflict
|
||||
# The correct endpoint is /settings/modules/enable/ which should go to ModuleEnableSettingsViewSet
|
||||
if pk == 'enable':
|
||||
return error_response(
|
||||
error='Use /api/v1/system/settings/modules/enable/ endpoint for module enable settings',
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
request=request
|
||||
)
|
||||
|
||||
queryset = self.get_queryset()
|
||||
try:
|
||||
# Try to get by ID first
|
||||
@@ -301,7 +310,7 @@ class ModuleEnableSettingsViewSet(AccountModelViewSet):
|
||||
Allow read access to all authenticated users,
|
||||
but restrict write access to admins/owners
|
||||
"""
|
||||
if self.action in ['list', 'retrieve']:
|
||||
if self.action in ['list', 'retrieve', 'get_current']:
|
||||
permission_classes = [IsAuthenticatedAndActive, HasTenantAccess]
|
||||
else:
|
||||
permission_classes = [IsAuthenticatedAndActive, HasTenantAccess, IsAdminOrOwner]
|
||||
@@ -321,6 +330,14 @@ class ModuleEnableSettingsViewSet(AccountModelViewSet):
|
||||
queryset = queryset.filter(account=account)
|
||||
return queryset
|
||||
|
||||
@action(detail=False, methods=['get', 'put'], url_path='current', url_name='current')
|
||||
def get_current(self, request):
|
||||
"""Get or update current account's module enable settings"""
|
||||
if request.method == 'GET':
|
||||
return self.list(request)
|
||||
else:
|
||||
return self.update(request, pk=None)
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
"""Get or create module enable settings for current account"""
|
||||
try:
|
||||
|
||||
@@ -16,8 +16,8 @@ router.register(r'strategies', StrategyViewSet, basename='strategy')
|
||||
router.register(r'settings/system', SystemSettingsViewSet, basename='system-settings')
|
||||
router.register(r'settings/account', AccountSettingsViewSet, basename='account-settings')
|
||||
router.register(r'settings/user', UserSettingsViewSet, basename='user-settings')
|
||||
# Register ModuleSettingsViewSet first
|
||||
router.register(r'settings/modules', ModuleSettingsViewSet, basename='module-settings')
|
||||
router.register(r'settings/modules/enable', ModuleEnableSettingsViewSet, basename='module-enable-settings')
|
||||
router.register(r'settings/ai', AISettingsViewSet, basename='ai-settings')
|
||||
|
||||
# Custom URL patterns for integration settings - matching reference plugin structure
|
||||
@@ -50,7 +50,20 @@ integration_image_gen_settings_viewset = IntegrationSettingsViewSet.as_view({
|
||||
'get': 'get_image_generation_settings',
|
||||
})
|
||||
|
||||
# Custom view for module enable settings to avoid URL routing conflict with ModuleSettingsViewSet
|
||||
# This must be defined as a custom path BEFORE router.urls to ensure it matches first
|
||||
# The update method handles pk=None correctly, so we can use as_view
|
||||
module_enable_viewset = ModuleEnableSettingsViewSet.as_view({
|
||||
'get': 'list',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
})
|
||||
|
||||
urlpatterns = [
|
||||
# Module enable settings endpoint - MUST come before router.urls to avoid conflict
|
||||
# When /settings/modules/enable/ is called, it would match ModuleSettingsViewSet with pk='enable'
|
||||
# So we define it as a custom path first
|
||||
path('settings/modules/enable/', module_enable_viewset, name='module-enable-settings'),
|
||||
path('', include(router.urls)),
|
||||
# Public health check endpoint (API Standard v1.0 requirement)
|
||||
path('ping/', ping, name='system-ping'),
|
||||
|
||||
@@ -644,9 +644,12 @@ class KeywordViewSet(SiteSectorModelViewSet):
|
||||
"data": {
|
||||
"user": { ... },
|
||||
"access": "eyJ0eXAiOiJKV1QiLCJhbGc...",
|
||||
"refresh": "eyJ0eXAiOiJKV1QiLCJhbGc..."
|
||||
"refresh": "eyJ0eXAiOiJKV1QiLCJhbGc...",
|
||||
"access_expires_at": "2025-01-XXT...",
|
||||
"refresh_expires_at": "2025-01-XXT..."
|
||||
},
|
||||
"message": "Login successful"
|
||||
"message": "Login successful",
|
||||
"request_id": "550e8400-e29b-41d4-a716-446655440000"
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -194,13 +194,14 @@ export async function fetchAPI(endpoint: string, options?: RequestInit & { timeo
|
||||
|
||||
if (refreshResponse.ok) {
|
||||
const refreshData = await refreshResponse.json();
|
||||
if (refreshData.success && refreshData.access) {
|
||||
const accessToken = refreshData.data?.access || refreshData.access;
|
||||
if (refreshData.success && accessToken) {
|
||||
// Update token in store
|
||||
try {
|
||||
const authStorage = localStorage.getItem('auth-storage');
|
||||
if (authStorage) {
|
||||
const parsed = JSON.parse(authStorage);
|
||||
parsed.state.token = refreshData.access;
|
||||
parsed.state.token = accessToken;
|
||||
localStorage.setItem('auth-storage', JSON.stringify(parsed));
|
||||
}
|
||||
} catch (e) {
|
||||
@@ -210,7 +211,7 @@ export async function fetchAPI(endpoint: string, options?: RequestInit & { timeo
|
||||
// Retry original request with new token
|
||||
const newHeaders = {
|
||||
...headers,
|
||||
'Authorization': `Bearer ${refreshData.access}`,
|
||||
'Authorization': `Bearer ${accessToken}`,
|
||||
};
|
||||
|
||||
const retryResponse = await fetch(`${API_BASE_URL}${endpoint}`, {
|
||||
|
||||
@@ -60,14 +60,17 @@ export const useAuthStore = create<AuthState>()(
|
||||
const data = await response.json();
|
||||
|
||||
if (!response.ok || !data.success) {
|
||||
throw new Error(data.message || 'Login failed');
|
||||
throw new Error(data.error || data.message || 'Login failed');
|
||||
}
|
||||
|
||||
// Store user and JWT tokens
|
||||
// Store user and JWT tokens (handle both old and new API formats)
|
||||
const responseData = data.data || data;
|
||||
// Support both formats: new (access/refresh at top level) and old (tokens.access/refresh)
|
||||
const tokens = responseData.tokens || {};
|
||||
set({
|
||||
user: data.user,
|
||||
token: data.tokens?.access || null,
|
||||
refreshToken: data.tokens?.refresh || null,
|
||||
user: responseData.user || data.user,
|
||||
token: responseData.access || tokens.access || data.access || null,
|
||||
refreshToken: responseData.refresh || tokens.refresh || data.refresh || null,
|
||||
isAuthenticated: true,
|
||||
loading: false
|
||||
});
|
||||
@@ -119,8 +122,8 @@ export const useAuthStore = create<AuthState>()(
|
||||
// Store user and JWT tokens
|
||||
set({
|
||||
user: data.user,
|
||||
token: data.tokens?.access || null,
|
||||
refreshToken: data.tokens?.refresh || null,
|
||||
token: data.data?.access || data.access || null,
|
||||
refreshToken: data.data?.refresh || data.refresh || null,
|
||||
isAuthenticated: true,
|
||||
loading: false
|
||||
});
|
||||
@@ -168,8 +171,8 @@ export const useAuthStore = create<AuthState>()(
|
||||
throw new Error(data.message || 'Token refresh failed');
|
||||
}
|
||||
|
||||
// Update access token
|
||||
set({ token: data.access });
|
||||
// Update access token (API returns access at top level of data)
|
||||
set({ token: data.data?.access || data.access });
|
||||
|
||||
// Also refresh user data to get latest account/plan information
|
||||
// This ensures account/plan changes are reflected immediately
|
||||
|
||||
Reference in New Issue
Block a user