2244 lines
90 KiB
Python
2244 lines
90 KiB
Python
"""
|
|
Authentication Views - Structured as: Groups, Users, Accounts, Subscriptions, Site User Access
|
|
Unified API Standard v1.0 compliant
|
|
"""
|
|
from rest_framework import viewsets, status, permissions, filters
|
|
from rest_framework.decorators import action
|
|
from rest_framework.response import Response
|
|
from rest_framework.views import APIView
|
|
from django.contrib.auth import authenticate
|
|
from django.utils import timezone
|
|
from django.db import transaction
|
|
from django_filters.rest_framework import DjangoFilterBackend
|
|
from drf_spectacular.utils import extend_schema, extend_schema_view
|
|
from igny8_core.api.base import AccountModelViewSet
|
|
from igny8_core.api.authentication import JWTAuthentication, CSRFExemptSessionAuthentication
|
|
from igny8_core.api.response import success_response, error_response
|
|
from igny8_core.api.throttles import DebugScopedRateThrottle
|
|
from igny8_core.api.pagination import CustomPageNumberPagination, LargeTablePagination
|
|
from igny8_core.api.permissions import IsAuthenticatedAndActive, HasTenantAccess
|
|
from .models import User, Account, Plan, Subscription, Site, Sector, SiteUserAccess, Industry, IndustrySector, SeedKeyword
|
|
from .serializers import (
|
|
UserSerializer, AccountSerializer, PlanSerializer, SubscriptionSerializer,
|
|
RegisterSerializer, LoginSerializer, ChangePasswordSerializer,
|
|
SiteSerializer, SectorSerializer, SiteUserAccessSerializer,
|
|
IndustrySerializer, IndustrySectorSerializer, SeedKeywordSerializer,
|
|
RefreshTokenSerializer, RequestPasswordResetSerializer, ResetPasswordSerializer
|
|
)
|
|
from .permissions import IsOwnerOrAdmin, IsEditorOrAbove, IsViewerOrAbove
|
|
from .utils import generate_access_token, generate_refresh_token, get_token_expiry, decode_token
|
|
from .models import PasswordResetToken
|
|
import jwt
|
|
|
|
|
|
# ============================================================================
|
|
# 1. GROUPS - Define user roles and permissions across the system
|
|
# ============================================================================
|
|
|
|
@extend_schema_view(
    list=extend_schema(tags=['Authentication']),
    retrieve=extend_schema(tags=['Authentication']),
)
class GroupsViewSet(viewsets.ViewSet):
    """
    Read-only catalogue of user roles ("groups") and their permissions.

    The role data is hard-coded in this class; it is meant to correspond to
    ``User.ROLE_CHOICES`` but nothing here reads that attribute.
    Unified API Standard v1.0 compliant
    """
    permission_classes = [IsOwnerOrAdmin]
    throttle_scope = 'auth'
    throttle_classes = [DebugScopedRateThrottle]

    def list(self, request):
        """Return every role with its id, display name, description and permissions."""
        # One row per role: (id, name, description, permissions).
        catalogue = (
            (
                'developer',
                'Developer / Super Admin',
                'Full access across all accounts (bypasses all filters)',
                ['full_access', 'bypass_filters', 'all_modules'],
            ),
            (
                'owner',
                'Owner',
                'Full account access, billing, automation',
                ['account_management', 'billing', 'automation', 'all_sites'],
            ),
            (
                'admin',
                'Admin',
                'Manage content modules, view billing (no edit)',
                ['content_management', 'view_billing', 'all_sites'],
            ),
            (
                'editor',
                'Editor',
                'Generate AI content, manage clusters/tasks',
                ['ai_content', 'manage_clusters', 'manage_tasks', 'assigned_sites'],
            ),
            (
                'viewer',
                'Viewer',
                'Read-only dashboards',
                ['read_only', 'assigned_sites'],
            ),
            (
                'system_bot',
                'System Bot',
                'System automation user',
                ['automation_only'],
            ),
        )
        groups = [
            {
                'id': role_id,
                'name': label,
                'description': summary,
                'permissions': grants,
            }
            for role_id, label, summary, grants in catalogue
        ]
        return success_response(data={'groups': groups}, request=request)

    @action(detail=False, methods=['get'], url_path='permissions')
    def permissions(self, request):
        """
        Return the permission list for the role named in the ``role`` query parameter.

        Responds with 400 when ``role`` is missing. An unknown role is not an
        error: it yields an empty permission list.
        """
        requested_role = request.query_params.get('role')
        if requested_role:
            # NOTE(review): this table differs from the one in list() for
            # developer/owner/admin (extra 'all_accounts' / 'user_management').
            permission_table = {
                'developer': ['full_access', 'bypass_filters', 'all_modules', 'all_accounts'],
                'owner': ['account_management', 'billing', 'automation', 'all_sites', 'user_management'],
                'admin': ['content_management', 'view_billing', 'all_sites', 'user_management'],
                'editor': ['ai_content', 'manage_clusters', 'manage_tasks', 'assigned_sites'],
                'viewer': ['read_only', 'assigned_sites'],
                'system_bot': ['automation_only'],
            }
            payload = {
                'role': requested_role,
                'permissions': permission_table.get(requested_role, []),
            }
            return success_response(data=payload, request=request)

        return error_response(
            error='role parameter is required',
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )
|
|
|
|
|
|
# ============================================================================
|
|
# 2. USERS - Manage global user records and credentials
|
|
# ============================================================================
|
|
|
|
@extend_schema_view(
    list=extend_schema(tags=['Authentication']),
    create=extend_schema(tags=['Authentication']),
    retrieve=extend_schema(tags=['Authentication']),
    update=extend_schema(tags=['Authentication']),
    partial_update=extend_schema(tags=['Authentication']),
    destroy=extend_schema(tags=['Authentication']),
)
class UsersViewSet(AccountModelViewSet):
    """
    ViewSet for managing global user records and credentials.

    Users are global, but belong to accounts.
    Unified API Standard v1.0 compliant
    """
    queryset = User.objects.all()
    serializer_class = UserSerializer
    permission_classes = [IsAuthenticatedAndActive, HasTenantAccess, IsOwnerOrAdmin]
    pagination_class = CustomPageNumberPagination
    throttle_scope = 'auth'
    throttle_classes = [DebugScopedRateThrottle]

    def get_queryset(self):
        """
        Return the users visible to the requesting user.

        Developers see every user, owners/admins with an account see the users
        of that account, and everyone else sees only their own record.
        """
        user = self.request.user
        if not user or not user.is_authenticated:
            return User.objects.none()

        # Developers can see all users
        if user.is_developer():
            return User.objects.all()

        # Owners/Admins can see users in their account
        if user.role in ['owner', 'admin'] and user.account:
            return User.objects.filter(account=user.account)

        # Others can only see themselves
        return User.objects.filter(id=user.id)

    def _check_assignable_role(self, request, role):
        """
        Validate that ``role`` exists and that the requester may assign it.

        Returns an error response to send back, or ``None`` when the role is
        acceptable. Shared by ``create_user`` and ``update_role`` so both
        endpoints enforce the same rules.
        """
        valid_roles = [choice[0] for choice in User.ROLE_CHOICES]
        if role not in valid_roles:
            return error_response(
                error=f'Invalid role. Must be one of: {valid_roles}',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )

        # The developer role grants access across all accounts, so an account
        # owner/admin must not be able to hand it out.
        if role == 'developer' and not request.user.is_developer():
            return error_response(
                error='Only developers can assign the developer role',
                status_code=status.HTTP_403_FORBIDDEN,
                request=request
            )
        return None

    @staticmethod
    def _can_manage_account(user, account):
        """Return True when ``user`` may create users inside ``account``."""
        if user.is_developer():
            return True
        own_account = user.account
        if own_account and own_account.pk == account.pk:
            return True
        # Owners may own accounts other than the one they are attached to.
        return getattr(account, 'owner_id', None) == user.pk

    @action(detail=False, methods=['post'])
    def create_user(self, request):
        """
        Create a new user (separate from registration).

        Request body: ``email``, ``username``, ``password`` (all required),
        optional ``role`` (defaults to ``viewer``) and optional ``account_id``
        (defaults to the requester's own account).

        Returns 201 with the serialized user, 400 on invalid input and 403 when
        the requester may not assign the role or target the account.
        """
        from django.contrib.auth.password_validation import validate_password
        from django.core.exceptions import ValidationError as DjangoValidationError

        email = request.data.get('email')
        username = request.data.get('username')
        password = request.data.get('password')
        role = request.data.get('role', 'viewer')
        account_id = request.data.get('account_id')

        # Require non-empty strings so that non-string JSON values are rejected
        # here instead of failing inside the password validators.
        credentials = (email, username, password)
        if not all(isinstance(value, str) and value for value in credentials):
            return error_response(
                error='email, username, and password are required',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )

        role_error = self._check_assignable_role(request, role)
        if role_error is not None:
            return role_error

        # Validate password
        try:
            validate_password(password)
        except DjangoValidationError as exc:
            return error_response(
                error=' '.join(exc.messages),
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )

        # Resolve the target account
        account = None
        if account_id:
            try:
                account = Account.objects.get(id=account_id)
            except (Account.DoesNotExist, ValueError, TypeError):
                return error_response(
                    error=f'Account with id {account_id} does not exist',
                    status_code=status.HTTP_400_BAD_REQUEST,
                    request=request
                )
            # Prevent creating users inside another tenant's account.
            if not self._can_manage_account(request.user, account):
                return error_response(
                    error='You do not have permission to create users in this account',
                    status_code=status.HTTP_403_FORBIDDEN,
                    request=request
                )
        elif request.user.account:
            # Use current user's account
            account = request.user.account

        # Create user
        try:
            user = User.objects.create_user(
                username=username,
                email=email,
                password=password,
                role=role,
                account=account
            )
        except Exception as e:
            # Deliberately broad: any creation failure (e.g. a duplicate
            # username) is reported to the client as a 400.
            return error_response(
                error=str(e),
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )

        serializer = UserSerializer(user)
        return success_response(
            data={'user': serializer.data},
            status_code=status.HTTP_201_CREATED,
            request=request
        )

    @action(detail=True, methods=['post'])
    def update_role(self, request, pk=None):
        """
        Change the role of the user identified by ``pk``.

        Returns 400 when ``role`` is missing or unknown and 403 when the
        requester may not assign it.
        """
        user = self.get_object()
        new_role = request.data.get('role')

        if not new_role:
            return error_response(
                error='role is required',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )

        role_error = self._check_assignable_role(request, new_role)
        if role_error is not None:
            return role_error

        user.role = new_role
        user.save()

        serializer = UserSerializer(user)
        return success_response(data={'user': serializer.data}, request=request)

    @action(detail=False, methods=['get', 'patch'], permission_classes=[IsAuthenticatedAndActive])
    def me(self, request):
        """
        Get (GET) or partially update (PATCH) the current user's profile.

        Only authentication is required; the owner/admin restriction of the
        view set does not apply to this action.
        """
        user = request.user

        if request.method == 'PATCH':
            # NOTE(review): which fields a user may change on themselves is
            # decided entirely by UserSerializer - confirm that role/account
            # are read-only there.
            serializer = UserSerializer(user, data=request.data, partial=True)
            if not serializer.is_valid():
                return error_response(
                    error='Validation failed',
                    errors=serializer.errors,
                    status_code=status.HTTP_400_BAD_REQUEST,
                    request=request
                )
            serializer.save()

        serializer = UserSerializer(user)
        return success_response(data={'user': serializer.data}, request=request)
|
|
|
|
|
|
# ============================================================================
|
|
# 3. ACCOUNTS - Register each unique organization/user space
|
|
# ============================================================================
|
|
|
|
@extend_schema_view(
    list=extend_schema(tags=['Authentication']),
    create=extend_schema(tags=['Authentication']),
    retrieve=extend_schema(tags=['Authentication']),
    update=extend_schema(tags=['Authentication']),
    partial_update=extend_schema(tags=['Authentication']),
    destroy=extend_schema(tags=['Authentication']),
)
class AccountsViewSet(AccountModelViewSet):
    """
    CRUD endpoints for accounts (one account per organization/user space).

    Unified API Standard v1.0 compliant
    """
    queryset = Account.objects.all()
    serializer_class = AccountSerializer
    permission_classes = [IsAuthenticatedAndActive, HasTenantAccess, IsOwnerOrAdmin]
    pagination_class = CustomPageNumberPagination
    throttle_scope = 'auth'
    throttle_classes = [DebugScopedRateThrottle]

    def get_queryset(self):
        """
        Scope the account list to what the requester is allowed to see.

        Developers get every account, owners get the accounts they own, admins
        get the single account they belong to; anyone else gets nothing.
        """
        requester = self.request.user
        if not (requester and requester.is_authenticated):
            return Account.objects.none()

        if requester.is_developer():
            return Account.objects.all()

        if requester.role == 'owner':
            return Account.objects.filter(owner=requester)

        if requester.role == 'admin' and requester.account:
            return Account.objects.filter(id=requester.account.id)

        return Account.objects.none()

    def perform_create(self, serializer):
        """
        Save a new account, defaulting the owner to the requesting user.

        Raises a DRF ``ValidationError`` when no plan was supplied.
        """
        requester = self.request.user
        validated = serializer.validated_data

        # The serializer maps the incoming plan_id onto 'plan' (source='plan').
        chosen_plan = validated.get('plan')
        if not chosen_plan:
            from rest_framework.exceptions import ValidationError
            raise ValidationError("plan_id is required")

        # Fall back to the requester when the payload names no owner.
        account_owner = validated.get('owner') or requester
        return serializer.save(plan=chosen_plan, owner=account_owner)
|
|
|
|
|
|
|
|
# ============================================================================
|
|
# 4. SUBSCRIPTIONS - Control plan level, limits, and billing per account
|
|
# ============================================================================
|
|
|
|
@extend_schema_view(
    list=extend_schema(tags=['Authentication']),
    create=extend_schema(tags=['Authentication']),
    retrieve=extend_schema(tags=['Authentication']),
    update=extend_schema(tags=['Authentication']),
    partial_update=extend_schema(tags=['Authentication']),
    destroy=extend_schema(tags=['Authentication']),
)
class SubscriptionsViewSet(AccountModelViewSet):
    """
    ViewSet for managing subscriptions (plan level, limits, billing per account).

    Unified API Standard v1.0 compliant
    """
    queryset = Subscription.objects.all()
    permission_classes = [IsAuthenticatedAndActive, HasTenantAccess, IsOwnerOrAdmin]
    pagination_class = CustomPageNumberPagination
    # Use relaxed auth throttle to avoid 429s during onboarding plan fetches
    throttle_scope = 'auth_read'
    throttle_classes = [DebugScopedRateThrottle]

    def get_queryset(self):
        """
        Return the subscriptions visible to the requesting user.

        Developers see all subscriptions, owners/admins see those of their own
        account, everyone else sees none.
        """
        user = self.request.user
        if not user or not user.is_authenticated:
            return Subscription.objects.none()

        # Developers can see all subscriptions
        if user.is_developer():
            return Subscription.objects.all()

        # Owners/Admins can see subscriptions for their account
        if user.role in ['owner', 'admin'] and user.account:
            return Subscription.objects.filter(account=user.account)

        return Subscription.objects.none()

    def get_serializer_class(self):
        """Return the serializer used for every action of this view set."""
        return SubscriptionSerializer

    @action(detail=False, methods=['get'], url_path='by-account/(?P<account_id>[^/.]+)')
    def by_account(self, request, account_id=None):
        """
        Return the subscription of the account given in the URL.

        The lookup goes through ``get_queryset()`` so a requester can only read
        subscriptions they are allowed to see; anything else is reported as
        404, exactly like a missing subscription.
        """
        try:
            subscription = self.get_queryset().get(account_id=account_id)
        except (Subscription.DoesNotExist, ValueError):
            # ValueError covers an account_id that cannot be converted to the
            # key type (the URL pattern accepts arbitrary text).
            return error_response(
                error='Subscription not found for this account',
                status_code=status.HTTP_404_NOT_FOUND,
                request=request
            )

        serializer = self.get_serializer(subscription)
        return success_response(
            data={'subscription': serializer.data},
            request=request
        )
|
|
|
|
|
|
# ============================================================================
|
|
# 5. SITE USER ACCESS - Assign users access to specific sites within account
|
|
# ============================================================================
|
|
|
|
@extend_schema_view(
    list=extend_schema(tags=['Authentication']),
    create=extend_schema(tags=['Authentication']),
    retrieve=extend_schema(tags=['Authentication']),
    update=extend_schema(tags=['Authentication']),
    partial_update=extend_schema(tags=['Authentication']),
    destroy=extend_schema(tags=['Authentication']),
)
class SiteUserAccessViewSet(AccountModelViewSet):
    """
    CRUD endpoints for Site-User access grants.

    Each record gives one user access to one site.
    Unified API Standard v1.0 compliant
    """
    serializer_class = SiteUserAccessSerializer
    permission_classes = [IsAuthenticatedAndActive, HasTenantAccess, IsOwnerOrAdmin]
    pagination_class = CustomPageNumberPagination
    throttle_scope = 'auth'
    throttle_classes = [DebugScopedRateThrottle]

    def get_queryset(self):
        """
        Limit access records to the sites of the requester's account.

        Developers see every record; unauthenticated users and users without
        an account see none.
        """
        requester = self.request.user
        if not (requester and requester.is_authenticated):
            return SiteUserAccess.objects.none()

        if requester.is_developer():
            return SiteUserAccess.objects.all()

        if requester.account:
            # Only grants attached to sites owned by the requester's account.
            return SiteUserAccess.objects.filter(site__account=requester.account)

        return SiteUserAccess.objects.none()

    def perform_create(self, serializer):
        """Persist the grant and record the requesting user as ``granted_by``."""
        # NOTE(review): nothing here checks that the target site belongs to the
        # requester's account - presumably the serializer does; verify.
        serializer.save(granted_by=self.request.user)
|
|
|
|
|
|
# ============================================================================
|
|
# SUPPORTING VIEWSETS (Sites, Sectors, Industries, Plans, Auth)
|
|
# ============================================================================
|
|
|
|
@extend_schema_view(
    list=extend_schema(tags=['Authentication']),
    retrieve=extend_schema(tags=['Authentication']),
)
class PlanViewSet(viewsets.ReadOnlyModelViewSet):
    """
    ViewSet for listing active subscription plans.

    Excludes internal-only plans (Free/Internal) from public listings.
    Unified API Standard v1.0 compliant
    """
    queryset = Plan.objects.filter(is_active=True, is_internal=False)
    serializer_class = PlanSerializer
    permission_classes = [permissions.AllowAny]
    pagination_class = CustomPageNumberPagination
    # Plans are public and should not throttle aggressively to avoid blocking signup/onboarding
    throttle_scope = None
    throttle_classes: list = []

    def list(self, request, *args, **kwargs):
        """Return the plans, paginated when the paginator yields a page, in the unified format."""
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(queryset, many=True)
        return success_response(data={'results': serializer.data}, request=request)

    def retrieve(self, request, *args, **kwargs):
        """
        Return a single plan in the unified format.

        Only a missing plan is turned into a 404 response; any other exception
        propagates so that real server errors are not reported as "not found".
        """
        from django.http import Http404

        try:
            instance = self.get_object()
        except Http404 as e:
            return error_response(
                error=str(e),
                status_code=status.HTTP_404_NOT_FOUND,
                request=request
            )
        serializer = self.get_serializer(instance)
        return success_response(data=serializer.data, request=request)
|
|
|
|
|
|
class _IsOwnerOnly(permissions.BasePermission):
    """Grant access only to superusers and to users whose role is owner or developer."""

    def has_permission(self, request, view):
        """Return True for an authenticated superuser, owner or developer; False otherwise."""
        candidate = getattr(request, 'user', None)
        if candidate and candidate.is_authenticated:
            # Superusers pass regardless of their role.
            if getattr(candidate, 'is_superuser', False):
                return True
            return getattr(candidate, 'role', '') in ['owner', 'developer']
        return False
|
|
|
|
|
|
@extend_schema_view(
    list=extend_schema(tags=['Authentication']),
    create=extend_schema(tags=['Authentication']),
    retrieve=extend_schema(tags=['Authentication']),
    update=extend_schema(tags=['Authentication']),
    partial_update=extend_schema(tags=['Authentication']),
    destroy=extend_schema(tags=['Authentication']),
)
class SiteViewSet(AccountModelViewSet):
    """
    ViewSet for managing Sites.

    Besides the standard CRUD actions it exposes ``sectors``, ``set_active``
    and ``select_sectors`` as per-site actions.
    """
    serializer_class = SiteSerializer
    permission_classes = [IsAuthenticatedAndActive, HasTenantAccess, IsEditorOrAbove]
    authentication_classes = [JWTAuthentication]

    def get_permissions(self):
        """Viewers can list/retrieve sites; creation restricted to owner; writes require editor+."""
        # Allow public read access for list requests with slug filter (used by Sites Renderer)
        if self.action == 'list' and self.request.query_params.get('slug'):
            return [permissions.AllowAny()]
        if self.action == 'create':
            # Only owners and developers can create new sites (admin cannot)
            return [permissions.IsAuthenticated(), _IsOwnerOnly()]
        if self.action in ['list', 'retrieve']:
            return [IsAuthenticatedAndActive(), HasTenantAccess(), IsViewerOrAbove()]
        return [IsAuthenticatedAndActive(), HasTenantAccess(), IsEditorOrAbove()]

    def get_queryset(self):
        """
        Return sites accessible to the current user.

        Anonymous requests only ever get the active site matching the ``slug``
        query parameter (or nothing). Authenticated users without an account
        get nothing; otherwise the result comes from
        ``user.get_accessible_sites()`` when the user object provides it, and
        from the account's sites as a fallback.
        """
        # If this is a public request (no auth) with slug filter, return site by slug
        if not self.request.user or not self.request.user.is_authenticated:
            slug = self.request.query_params.get('slug')
            if slug:
                # Return queryset directly from model (bypassing base class account filtering)
                return Site.objects.filter(slug=slug, is_active=True)
            return Site.objects.none()

        user = self.request.user

        account = getattr(user, 'account', None)
        if not account:
            return Site.objects.none()

        if hasattr(user, 'get_accessible_sites'):
            return user.get_accessible_sites()

        return Site.objects.filter(account=account)

    def perform_create(self, serializer):
        """
        Create a site for the request's account and auto-grant access to its creator.

        Raises ``PermissionDenied`` when the account's hard limit on sites
        would be exceeded.
        """
        user = self.request.user
        account = getattr(self.request, 'account', None)
        if not account and user and user.is_authenticated:
            account = getattr(user, 'account', None)
        # NOTE(review): account may still be None here; whether LimitService
        # and the Site model accept that is not visible from this file.

        # Check hard limit for sites
        from igny8_core.business.billing.services.limit_service import LimitService, HardLimitExceededError
        try:
            LimitService.check_hard_limit(account, 'sites', additional_count=1)
        except HardLimitExceededError as e:
            from rest_framework.exceptions import PermissionDenied
            raise PermissionDenied(str(e))

        # Multiple sites can be active simultaneously - no constraint
        site = serializer.save(account=account)

        # Auto-create SiteUserAccess for owner/admin who creates the site
        if user and user.is_authenticated and hasattr(user, 'role'):
            if user.role in ['owner', 'admin']:
                SiteUserAccess.objects.get_or_create(
                    user=user,
                    site=site,
                    defaults={'granted_by': user}
                )

    def perform_update(self, serializer):
        """Update site."""
        # Multiple sites can be active simultaneously - no constraint
        serializer.save()

    @action(detail=True, methods=['get'])
    def sectors(self, request, pk=None):
        """Get all active sectors for this site."""
        site = self.get_object()
        sectors = site.sectors.filter(is_active=True)
        serializer = SectorSerializer(sectors, many=True)
        return success_response(
            data=serializer.data,
            request=request
        )

    @action(detail=True, methods=['post'], url_path='set_active')
    def set_active(self, request, pk=None):
        """Set this site as active (multiple sites can be active simultaneously)."""
        site = self.get_object()

        # Simply activate this site - no need to deactivate others
        site.is_active = True
        site.status = 'active'
        site.save()

        serializer = self.get_serializer(site)
        return success_response(
            data={'site': serializer.data},
            message=f'Site "{site.name}" is now active',
            request=request
        )

    @action(detail=True, methods=['post'], url_path='select_sectors')
    def select_sectors(self, request, pk=None):
        """
        Select industry and sectors for this site.

        Request body: ``industry_slug`` (required) and ``sector_slugs``
        (optional list of sector slugs belonging to that industry).

        With no sectors, only the site's industry is set. Otherwise every
        requested sector is created or re-activated from its industry sector
        template and all other sectors of the site are deactivated.

        All input is validated before anything is written, and the writes run
        in a single transaction, so a rejected request leaves the site
        unchanged.
        """
        import logging
        from django.http import Http404

        logger = logging.getLogger(__name__)

        # Only "not found" becomes a 404 here; permission errors keep their
        # own status via the framework's exception handling.
        try:
            site = self.get_object()
        except Http404 as exc:
            return error_response(
                error=f'Site not found: {str(exc)}',
                status_code=status.HTTP_404_NOT_FOUND,
                request=request
            )

        industry_slug = request.data.get('industry_slug')
        if not industry_slug:
            return error_response(
                error='Industry slug is required',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )

        raw_slugs = request.data.get('sector_slugs') or []
        if not isinstance(raw_slugs, (list, tuple)) or not all(
            isinstance(item, str) for item in raw_slugs
        ):
            # A bare string would otherwise be iterated character by character.
            return error_response(
                error='sector_slugs must be a list of sector slugs',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )
        # Drop duplicates (keeping order) so they do not count against the limit.
        sector_slugs = list(dict.fromkeys(raw_slugs))

        try:
            industry = Industry.objects.get(slug=industry_slug, is_active=True)
        except Industry.DoesNotExist:
            return error_response(
                error=f'Industry with slug "{industry_slug}" not found',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )

        if not sector_slugs:
            site.industry = industry
            site.save()
            return success_response(
                data={
                    'site': SiteSerializer(site).data,
                    'sectors': []
                },
                message=f'Industry "{industry.name}" set for site. No sectors selected.',
                request=request
            )

        # Sector limit for this site as reported by the Site model
        # (presumably plan-dependent - see Site.get_max_sectors_limit).
        max_sectors = site.get_max_sectors_limit()
        if len(sector_slugs) > max_sectors:
            return error_response(
                error=f'Maximum {max_sectors} sectors allowed per site for this plan',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )

        # Resolve every requested slug before touching the database.
        industry_sectors_map = {}
        for sector_slug in sector_slugs:
            industry_sector = IndustrySector.objects.filter(
                industry=industry,
                slug=sector_slug,
                is_active=True
            ).first()
            if not industry_sector:
                return error_response(
                    error=f'Sector "{sector_slug}" not found in industry "{industry.name}"',
                    status_code=status.HTTP_400_BAD_REQUEST,
                    request=request
                )
            industry_sectors_map[sector_slug] = industry_sector

        # Sectors need an account; refuse early if the site has none.
        if not site.account:
            logger.error(f"Site {site.id} has no account assigned")
            return error_response(
                error=f'Site "{site.name}" has no account assigned. Please contact support.',
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                request=request
            )

        created_sectors = []
        updated_sectors = []
        # Slug being processed when a failure occurs; None outside the loop so
        # that unrelated errors are re-raised instead of being misreported.
        failing_slug = None

        try:
            with transaction.atomic():
                site.industry = industry
                site.save()

                # Deactivate every sector that is not part of the new selection.
                site.sectors.exclude(slug__in=set(sector_slugs)).update(is_active=False)

                for sector_slug, industry_sector in industry_sectors_map.items():
                    failing_slug = sector_slug
                    # The account is passed in defaults so that a newly created
                    # sector has it set from the start.
                    sector, created = Sector.objects.get_or_create(
                        site=site,
                        slug=sector_slug,
                        defaults={
                            'industry_sector': industry_sector,
                            'name': industry_sector.name,
                            'description': industry_sector.description or '',
                            'is_active': True,
                            'status': 'active',
                            'account': site.account  # Pass the account object, not the ID
                        }
                    )

                    if created:
                        created_sectors.append(sector)
                        continue

                    # Refresh an existing sector from its template and re-activate it.
                    sector.industry_sector = industry_sector
                    sector.name = industry_sector.name
                    sector.description = industry_sector.description or ''
                    sector.is_active = True
                    sector.status = 'active'
                    if not sector.account:
                        sector.account = site.account
                    sector.save()
                    updated_sectors.append(sector)
                failing_slug = None

                # Final safety check against the limit; undo everything if it fails.
                if site.get_active_sectors_count() > max_sectors:
                    transaction.set_rollback(True)
                    return error_response(
                        error=f'Maximum {max_sectors} sectors allowed per site for this plan',
                        status_code=status.HTTP_400_BAD_REQUEST,
                        request=request
                    )
        except Exception as e:
            if failing_slug is None:
                raise
            logger.error(f"Error creating/updating sector {failing_slug}: {str(e)}", exc_info=True)
            return error_response(
                error=f'Failed to create/update sector "{failing_slug}": {str(e)}',
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                request=request
            )

        serializer = SectorSerializer(site.sectors.filter(is_active=True), many=True)
        return success_response(
            data={
                'created_count': len(created_sectors),
                'updated_count': len(updated_sectors),
                'sectors': serializer.data,
                'site': SiteSerializer(site).data
            },
            message=f'Selected {len(sector_slugs)} sectors from industry "{industry.name}".',
            request=request
        )
|
|
|
|
|
|
@extend_schema_view(
    list=extend_schema(tags=['Authentication']),
    create=extend_schema(tags=['Authentication']),
    retrieve=extend_schema(tags=['Authentication']),
    update=extend_schema(tags=['Authentication']),
    partial_update=extend_schema(tags=['Authentication']),
    destroy=extend_schema(tags=['Authentication']),
)
class SectorViewSet(AccountModelViewSet):
    """CRUD endpoints for Sectors, limited to sites the requester can access."""
    serializer_class = SectorSerializer
    permission_classes = [IsAuthenticatedAndActive, HasTenantAccess, IsEditorOrAbove]
    authentication_classes = [JWTAuthentication]

    def get_permissions(self):
        """Read actions need viewer level or above; all other actions need editor level or above."""
        read_only = self.action in ['list', 'retrieve']
        role_gate = IsViewerOrAbove() if read_only else IsEditorOrAbove()
        return [IsAuthenticatedAndActive(), HasTenantAccess(), role_gate]

    def get_queryset(self):
        """Return the sectors of every site in ``user.get_accessible_sites()``; none for anonymous users."""
        requester = self.request.user
        if requester and requester.is_authenticated:
            return Sector.objects.filter(site__in=requester.get_accessible_sites())
        return Sector.objects.none()

    def get_queryset_with_site_filter(self):
        """Return ``get_queryset()``, narrowed to one site when ``site_id`` is in the query string."""
        sectors = self.get_queryset()
        requested_site = self.request.query_params.get('site_id')
        return sectors.filter(site_id=requested_site) if requested_site else sectors

    def list(self, request, *args, **kwargs):
        """Return all matching sectors (unpaginated), honouring the optional ``site_id`` filter."""
        serialized = self.get_serializer(self.get_queryset_with_site_filter(), many=True)
        return success_response(data=serialized.data, request=request)
|
|
|
|
|
|
@extend_schema_view(
    list=extend_schema(tags=['Authentication']),
    retrieve=extend_schema(tags=['Authentication']),
)
class IndustryViewSet(viewsets.ReadOnlyModelViewSet):
    """
    ViewSet for industry templates.

    Unified API Standard v1.0 compliant
    """
    queryset = Industry.objects.filter(is_active=True).prefetch_related('sectors')
    serializer_class = IndustrySerializer
    permission_classes = [permissions.AllowAny]
    pagination_class = CustomPageNumberPagination
    throttle_scope = 'auth'
    throttle_classes = [DebugScopedRateThrottle]

    def list(self, request, *args, **kwargs):
        """
        Get all active industries with their sectors (unpaginated).

        Accepts ``*args``/``**kwargs`` like the other list handlers in this
        module, so extra URL keyword arguments (for example a ``format``
        suffix) do not raise ``TypeError``.
        """
        industries = self.get_queryset()
        serializer = self.get_serializer(industries, many=True)
        return success_response(
            data={'industries': serializer.data},
            request=request
        )

    def retrieve(self, request, *args, **kwargs):
        """
        Return a single industry in the unified format.

        Only a missing industry is turned into a 404 response; any other
        exception propagates so that real server errors are not reported as
        "not found".
        """
        from django.http import Http404

        try:
            instance = self.get_object()
        except Http404 as e:
            return error_response(
                error=str(e),
                status_code=status.HTTP_404_NOT_FOUND,
                request=request
            )
        serializer = self.get_serializer(instance)
        return success_response(data=serializer.data, request=request)
|
|
|
|
|
|
@extend_schema_view(
|
|
list=extend_schema(tags=['Authentication']),
|
|
retrieve=extend_schema(tags=['Authentication']),
|
|
)
|
|
class SeedKeywordViewSet(viewsets.ReadOnlyModelViewSet):
|
|
"""
|
|
ViewSet for SeedKeyword - Global reference data (read-only for non-admins).
|
|
Unified API Standard v1.0 compliant
|
|
|
|
Sorting and filtering is applied server-side to ALL records, then paginated.
|
|
This ensures operations like "sort by volume DESC" return the globally highest
|
|
volume keywords, not just the highest within the current page.
|
|
"""
|
|
queryset = SeedKeyword.objects.filter(is_active=True).select_related('industry', 'sector')
|
|
serializer_class = SeedKeywordSerializer
|
|
permission_classes = [permissions.AllowAny] # Read-only, allow any authenticated user
|
|
pagination_class = LargeTablePagination # Supports up to 500 records per page
|
|
throttle_scope = 'auth'
|
|
throttle_classes = [DebugScopedRateThrottle]
|
|
|
|
filter_backends = [filters.SearchFilter, filters.OrderingFilter, DjangoFilterBackend]
|
|
search_fields = ['keyword']
|
|
ordering_fields = ['keyword', 'volume', 'difficulty', 'created_at']
|
|
ordering = ['keyword']
|
|
filterset_fields = ['industry', 'sector', 'country', 'is_active']
|
|
|
|
def retrieve(self, request, *args, **kwargs):
    """Return a single seed keyword in the unified response format.

    Only a genuine "not found" (``Http404`` raised by ``get_object``) is
    turned into a unified 404 error response. Every other exception is left
    to propagate to the framework's exception handling instead of being
    mislabelled as a 404 with its raw message exposed to the client.
    """
    from django.http import Http404

    try:
        instance = self.get_object()
    except Http404 as exc:
        return error_response(
            # Http404 may carry no message (e.g. malformed lookup value).
            error=str(exc) or 'Not found',
            status_code=status.HTTP_404_NOT_FOUND,
            request=request,
        )

    serializer = self.get_serializer(instance)
    return success_response(data=serializer.data, request=request)
|
|
|
|
def get_queryset(self):
    """Narrow the base queryset using the optional query-string filters.

    Recognised parameters (all optional):
    - industry_id, sector_id: exact id match
    - sector_ids: comma-separated ids, only consulted when sector_id is absent
    - industry_name, sector_name: case-insensitive substring match
    - difficulty_min / difficulty_max, volume_min / volume_max: inclusive bounds
    - min_words: minimum number of whitespace-separated words (values <= 1
      apply no filter)
    - available_only (true/1/yes) together with site_id: drop seed keywords
      already attached to that site

    A malformed numeric value never raises: that single filter is skipped,
    which is how the range filters have always behaved and now also covers
    industry_id, sector_id and site_id.
    """
    params = self.request.query_params
    queryset = super().get_queryset()

    def as_int(raw):
        """Return ``raw`` as an int, or None when missing or not an integer."""
        if raw is None:
            return None
        try:
            return int(raw)
        except (ValueError, TypeError):
            return None

    industry_id = as_int(params.get('industry_id'))
    if industry_id is not None:
        queryset = queryset.filter(industry_id=industry_id)

    industry_name = params.get('industry_name')
    if industry_name:
        queryset = queryset.filter(industry__name__icontains=industry_name)

    # A single sector_id takes precedence over the comma-separated sector_ids.
    raw_sector_id = params.get('sector_id')
    raw_sector_ids = params.get('sector_ids')
    if raw_sector_id:
        sector_id = as_int(raw_sector_id)
        if sector_id is not None:
            queryset = queryset.filter(sector_id=sector_id)
    elif raw_sector_ids:
        try:
            ids_list = [int(s.strip()) for s in raw_sector_ids.split(',') if s.strip()]
        except (ValueError, TypeError):
            ids_list = []
        if ids_list:
            queryset = queryset.filter(sector_id__in=ids_list)

    sector_name = params.get('sector_name')
    if sector_name:
        queryset = queryset.filter(sector__name__icontains=sector_name)

    # Inclusive difficulty / volume bounds.
    range_lookups = (
        ('difficulty__gte', 'difficulty_min'),
        ('difficulty__lte', 'difficulty_max'),
        ('volume__gte', 'volume_min'),
        ('volume__lte', 'volume_max'),
    )
    for lookup, param_name in range_lookups:
        bound = as_int(params.get(param_name))
        if bound is not None:
            queryset = queryset.filter(**{lookup: bound})

    # Word-count filter: N words means at least N-1 "word + whitespace"
    # groups followed by a final word (min_words=4 gives the long-tail filter).
    min_words = as_int(params.get('min_words'))
    if min_words is not None and min_words > 1:
        pattern = r'^(\S+\s+){' + str(min_words - 1) + r',}\S+$'
        queryset = queryset.filter(keyword__regex=pattern)

    # Availability filter - exclude keywords already added to the site.
    available_only = params.get('available_only')
    site_id = as_int(params.get('site_id'))
    wants_available = bool(available_only) and str(available_only).lower() in ['true', '1', 'yes']
    if wants_available and site_id is not None:
        from igny8_core.business.planning.models import Keywords
        attached_ids = Keywords.objects.filter(
            site_id=site_id,
            seed_keyword__isnull=False,
        ).values_list('seed_keyword_id', flat=True)
        queryset = queryset.exclude(id__in=attached_ids)

    return queryset
|
|
|
|
@action(detail=False, methods=['get'], url_path='stats', url_name='stats')
def stats(self, request):
    """
    Return aggregated statistics over active seed keywords.

    Response data:
    - industries: at most 10 industries that have at least one active
      keyword, ordered by keyword count (descending), each with name, slug,
      keyword_count and total_volume
    - countries: one entry per country value, ordered by keyword count
      (descending), with the label taken from SeedKeyword.COUNTRY_CHOICES
    - total_keywords / total_volume: totals across all active keywords

    Any exception is reported as a unified 500 error response.
    """
    from django.db.models import Count, Sum, Q

    try:
        # Top industries by keyword count (only active keywords are counted).
        active_filter = Q(seed_keywords__is_active=True)
        top_industries = Industry.objects.annotate(
            keyword_count=Count('seed_keywords', filter=active_filter),
            total_volume=Sum('seed_keywords__volume', filter=active_filter),
        ).filter(
            keyword_count__gt=0
        ).order_by('-keyword_count')[:10]

        industries_data = [
            {
                'name': ind.name,
                'slug': ind.slug,
                'keyword_count': ind.keyword_count or 0,
                # Sum() is None when there is nothing to add up.
                'total_volume': ind.total_volume or 0,
            }
            for ind in top_industries
        ]

        active_keywords = SeedKeyword.objects.filter(is_active=True)

        # Keywords by country.
        per_country = active_keywords.values('country').annotate(
            keyword_count=Count('id'),
            total_volume=Sum('volume'),
        ).order_by('-keyword_count')

        # Build the code -> label map once instead of once per country row.
        country_labels = dict(SeedKeyword.COUNTRY_CHOICES)
        countries_data = [
            {
                'country': row['country'],
                'country_display': country_labels.get(row['country'], row['country']),
                'keyword_count': row['keyword_count'],
                'total_volume': row['total_volume'] or 0,
            }
            for row in per_country
        ]

        # Overall totals.
        totals = active_keywords.aggregate(
            total_keywords=Count('id'),
            total_volume=Sum('volume'),
        )

        data = {
            'industries': industries_data,
            'countries': countries_data,
            'total_keywords': totals['total_keywords'] or 0,
            'total_volume': totals['total_volume'] or 0,
        }
        return success_response(data=data, request=request)
    except Exception as e:
        return error_response(
            error=f'Failed to fetch keyword stats: {str(e)}',
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            request=request
        )
|
|
|
|
@action(detail=False, methods=['post'], url_path='import_seed_keywords', url_name='import_seed_keywords')
def import_seed_keywords(self, request):
    """
    Import seed keywords from an uploaded CSV file (staff/superuser only).

    Expected columns: keyword, industry_name, sector_name, volume, difficulty, country

    Rows are skipped when a required column is empty, when the industry or
    sector does not already exist (nothing is created for them), or when the
    same keyword already exists for that industry and sector. Per-row
    failures are collected and the first 10 are returned.

    Responses: 403 for non-admin callers, 400 for a missing / non-CSV /
    non-UTF-8 upload, 500 for an unexpected failure, otherwise a success
    payload with ``imported``, ``skipped`` and ``errors``.
    """
    import csv
    import io
    from django.db import transaction

    # Check admin/superuser permission
    if not (request.user.is_staff or request.user.is_superuser):
        return error_response(
            error='Admin or superuser access required',
            status_code=status.HTTP_403_FORBIDDEN,
            request=request
        )

    if 'file' not in request.FILES:
        return error_response(
            error='No file provided',
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )

    upload = request.FILES['file']
    # Accept ".CSV" as well as ".csv".
    if not upload.name.lower().endswith('.csv'):
        return error_response(
            error='File must be a CSV',
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )

    try:
        # utf-8-sig strips a leading BOM (common in spreadsheet exports);
        # otherwise the first header would not match 'keyword'.
        decoded_file = upload.read().decode('utf-8-sig')
    except UnicodeDecodeError:
        return error_response(
            error='File must be UTF-8 encoded',
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )

    try:
        # newline='' lets the csv module handle line endings itself, so
        # quoted fields containing line breaks are parsed correctly.
        csv_reader = csv.DictReader(io.StringIO(decoded_file, newline=''))

        imported_count = 0
        skipped_count = 0
        errors = []

        with transaction.atomic():
            for row_num, row in enumerate(csv_reader, start=2):  # Start at 2 (header is row 1)
                try:
                    # Nested atomic() = savepoint per row: a database error
                    # on one row is rolled back on its own and does not
                    # leave the outer transaction unusable for later rows.
                    with transaction.atomic():
                        # Short rows give None for missing columns, hence "or ''".
                        keyword_text = (row.get('keyword') or '').strip()
                        industry_name = (row.get('industry_name') or '').strip()
                        sector_name = (row.get('sector_name') or '').strip()

                        if not all([keyword_text, industry_name, sector_name]):
                            skipped_count += 1
                            continue

                        # Look up the industry (unknown industries are not created)
                        industry = Industry.objects.filter(name=industry_name).first()
                        if not industry:
                            errors.append(f"Row {row_num}: Industry '{industry_name}' not found")
                            skipped_count += 1
                            continue

                        # Look up the sector within that industry (not created either)
                        sector = IndustrySector.objects.filter(
                            industry=industry,
                            name=sector_name
                        ).first()
                        if not sector:
                            errors.append(f"Row {row_num}: Sector '{sector_name}' not found for industry '{industry_name}'")
                            skipped_count += 1
                            continue

                        # Skip keywords that already exist for this industry/sector
                        already_present = SeedKeyword.objects.filter(
                            keyword=keyword_text,
                            industry=industry,
                            sector=sector
                        ).exists()
                        if already_present:
                            skipped_count += 1
                            continue

                        # Create seed keyword
                        SeedKeyword.objects.create(
                            keyword=keyword_text,
                            industry=industry,
                            sector=sector,
                            volume=int(row.get('volume', 0) or 0),
                            difficulty=int(row.get('difficulty', 0) or 0),
                            country=(row.get('country') or '').strip() or 'US',
                            is_active=True
                        )
                        imported_count += 1
                except Exception as e:
                    errors.append(f"Row {row_num}: {str(e)}")
                    skipped_count += 1

        return success_response(
            data={
                'imported': imported_count,
                'skipped': skipped_count,
                'errors': errors[:10]  # Limit errors to first 10
            },
            message=f'Import completed: {imported_count} keywords imported, {skipped_count} skipped',
            request=request
        )
    except Exception as e:
        return error_response(
            error=f'Failed to import keywords: {str(e)}',
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            request=request
        )
|
|
|
|
@action(detail=False, methods=['get'], url_path='sector_stats', url_name='sector_stats')
def sector_stats(self, request):
    """
    Get sector-level statistics for the Keywords Library dashboard.

    Query parameters:
    - industry_id (required): industry whose active keywords are analysed
    - sector_id: return one stats object for this sector only
    - sector_ids: comma-separated IndustrySector IDs restricting the
      per-sector listing (ignored when sector_id is given or the list is malformed)
    - site_id: when given, "available" counts exclude seed keywords
      already attached to that site

    Stats returned for each sector:
    - total: all keywords
    - available: keywords not yet added to the site
    - high_volume: volume >= 10000
    - mid_volume: 5000 <= volume < 10000
    - premium_traffic: volume >= first of 50000 / 25000 / 10000 with results
    - long_tail: 4+ words and volume >= first of 1000 / 500 / 200 with results
    - quick_wins: difficulty <= 20, not yet added (when site_id is given)
      and volume >= first of 1000 / 500 / 200 with results

    Without sector_id the response lists every sector of the industry that
    has at least one keyword. Non-integer ids give a 400 response; any
    other failure gives a unified 500 response.
    """
    try:
        params = request.query_params
        raw_industry_id = params.get('industry_id')
        raw_sector_id = params.get('sector_id')
        sector_ids = params.get('sector_ids')  # Comma-separated list
        raw_site_id = params.get('site_id')

        if not raw_industry_id:
            return error_response(
                error='industry_id is required',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )

        # Validate ids up front so bad input is a client error, not a 500.
        try:
            industry_id = int(raw_industry_id)
            sector_id = int(raw_sector_id) if raw_sector_id else None
            site_id = int(raw_site_id) if raw_site_id else None
        except (ValueError, TypeError):
            return error_response(
                error='industry_id, sector_id and site_id must be integers',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )

        # Base queryset for the industry
        base_qs = SeedKeyword.objects.filter(
            is_active=True,
            industry_id=industry_id
        )

        # Seed keyword IDs the site already uses (only when site_id is provided)
        already_added_ids = set()
        if site_id is not None:
            from igny8_core.business.planning.models import Keywords
            already_added_ids = set(
                Keywords.objects.filter(
                    site_id=site_id,
                    seed_keyword__isnull=False
                ).values_list('seed_keyword_id', flat=True)
            )

        def count_available(qs):
            """Count qs rows, leaving out keywords the site already has."""
            if site_id is None:
                return qs.count()
            return qs.exclude(id__in=already_added_ids).count()

        def get_count_with_fallback(qs, thresholds, volume_field='volume'):
            """Try thresholds in order, return counts for the first with results."""
            for threshold in thresholds:
                filtered = qs.filter(**{f'{volume_field}__gte': threshold})
                total_count = filtered.count()
                if total_count > 0:
                    return {
                        'count': total_count,
                        'available': count_available(filtered),
                        'threshold': threshold,
                    }
            return {'count': 0, 'available': 0, 'threshold': thresholds[-1]}

        premium_thresholds = [50000, 25000, 10000]
        long_tail_thresholds = [1000, 500, 200]
        quick_wins_thresholds = [1000, 500, 200]

        def build_stats(qs, total):
            """Compute the full stats dict for one keyword queryset."""
            high_volume_qs = qs.filter(volume__gte=10000)
            mid_volume_qs = qs.filter(volume__gte=5000, volume__lt=10000)
            # 4+ words: at least three "word + whitespace" groups, then a word.
            long_tail_qs = qs.filter(keyword__regex=r'^(\S+\s+){3,}\S+$')
            quick_wins_qs = qs.filter(difficulty__lte=20)
            if site_id is not None:
                quick_wins_qs = quick_wins_qs.exclude(id__in=already_added_ids)
            return {
                'total': {'count': total},
                'available': {'count': count_available(qs)},
                'high_volume': {
                    'count': high_volume_qs.count(),
                    'available': count_available(high_volume_qs),
                    'threshold': 10000,
                },
                'mid_volume': {
                    'count': mid_volume_qs.count(),
                    'available': count_available(mid_volume_qs),
                    'threshold': 5000,
                },
                'premium_traffic': get_count_with_fallback(qs, premium_thresholds),
                'long_tail': get_count_with_fallback(long_tail_qs, long_tail_thresholds),
                'quick_wins': get_count_with_fallback(quick_wins_qs, quick_wins_thresholds),
            }

        if sector_id is not None:
            # Single-sector mode: one stats object.
            sector_qs = base_qs.filter(sector_id=sector_id)
            data = {
                'sector_id': sector_id,
                'stats': build_stats(sector_qs, sector_qs.count()),
            }
        else:
            # Per-sector mode: stats for each sector of the industry,
            # optionally restricted to sector_ids (site-specific sectors).
            sectors = IndustrySector.objects.filter(industry_id=industry_id)
            if sector_ids:
                try:
                    ids_list = [int(s.strip()) for s in sector_ids.split(',') if s.strip()]
                except (ValueError, TypeError):
                    ids_list = []
                if ids_list:
                    sectors = sectors.filter(id__in=ids_list)

            sectors_data = []
            for sector in sectors:
                sector_qs = base_qs.filter(sector=sector)
                sector_total = sector_qs.count()
                if sector_total == 0:
                    continue  # sectors without keywords are omitted
                sectors_data.append({
                    'sector_id': sector.id,
                    'sector_name': sector.name,
                    'stats': build_stats(sector_qs, sector_total),
                })

            data = {
                'industry_id': industry_id,
                'sectors': sectors_data,
            }

        return success_response(data=data, request=request)

    except Exception as e:
        return error_response(
            error=f'Failed to fetch sector stats: {str(e)}',
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            request=request
        )
|
|
|
|
@action(detail=False, methods=['get'], url_path='filter_options', url_name='filter_options')
def filter_options(self, request):
    """
    Get cascading filter options for the Keywords Library.

    Returns:
    - industries: industries with at least one active keyword
    - sectors: sectors of industry_id with at least one active keyword
      (empty when industry_id is not given)
    - countries: country values with keyword counts
    - difficulty: min/max range plus the non-empty difficulty levels
    - volume: min/max range

    Cascading: each facet (countries, difficulty, volume) is computed with
    every currently selected filter applied EXCEPT the facet's own, so the
    options stay selectable. industry_id, sector_id, min_words and
    available_only + site_id narrow all facets. Malformed numeric values
    are ignored. Any failure is reported as a unified 500 response.
    """
    from django.db.models import Count, Min, Max, Q, Value
    from django.db.models.functions import Length, Replace

    def as_int(raw):
        """Return ``raw`` as an int, or None when missing or not an integer."""
        if raw is None:
            return None
        try:
            return int(raw)
        except (ValueError, TypeError):
            return None

    try:
        params = request.query_params
        industry_id = params.get('industry_id')
        sector_id = params.get('sector_id')
        country_filter = params.get('country')
        search_term = params.get('search')
        difficulty_min = as_int(params.get('difficulty_min'))
        difficulty_max = as_int(params.get('difficulty_max'))
        volume_min = as_int(params.get('volume_min'))
        volume_max = as_int(params.get('volume_max'))
        min_words = as_int(params.get('min_words'))
        site_id = as_int(params.get('site_id'))
        # Same truthy spellings as the list endpoint's available_only filter.
        available_only = str(params.get('available_only') or '').lower() in ['true', '1', 'yes']

        active_keywords = Q(seed_keywords__is_active=True)

        # Industries with keyword counts
        industries = Industry.objects.annotate(
            keyword_count=Count('seed_keywords', filter=active_keywords)
        ).filter(keyword_count__gt=0).order_by('name')
        industries_data = [
            {
                'id': ind.id,
                'name': ind.name,
                'slug': ind.slug,
                'keyword_count': ind.keyword_count,
            }
            for ind in industries
        ]

        # Sectors are only listed once an industry has been chosen
        sectors_data = []
        if industry_id:
            sectors = IndustrySector.objects.filter(
                industry_id=industry_id
            ).annotate(
                keyword_count=Count('seed_keywords', filter=active_keywords)
            ).filter(keyword_count__gt=0).order_by('name')
            sectors_data = [
                {
                    'id': sec.id,
                    'name': sec.name,
                    'slug': sec.slug,
                    'keyword_count': sec.keyword_count,
                }
                for sec in sectors
            ]

        # Base queryset shared by all cascading facets
        base_qs = SeedKeyword.objects.filter(is_active=True)
        if industry_id:
            base_qs = base_qs.filter(industry_id=industry_id)
        if sector_id:
            base_qs = base_qs.filter(sector_id=sector_id)

        # min_words filter: word count = number of spaces + 1
        if min_words is not None:
            base_qs = base_qs.annotate(
                word_count=Length('keyword') - Length(Replace('keyword', Value(' '), Value(''))) + 1
            ).filter(word_count__gte=min_words)

        # available_only filter: exclude keywords already added to the site
        if available_only and site_id is not None:
            from igny8_core.business.planning.models import Keywords
            existing_seed_ids = Keywords.objects.filter(
                site_id=site_id,
                seed_keyword__isnull=False
            ).values_list('seed_keyword_id', flat=True)
            base_qs = base_qs.exclude(id__in=existing_seed_ids)

        def narrowed(facet):
            """Return base_qs with every selected filter except ``facet``'s own."""
            qs = base_qs
            if facet != 'country' and country_filter:
                qs = qs.filter(country=country_filter)
            if facet != 'difficulty':
                if difficulty_min is not None:
                    qs = qs.filter(difficulty__gte=difficulty_min)
                if difficulty_max is not None:
                    qs = qs.filter(difficulty__lte=difficulty_max)
            if facet != 'volume':
                if volume_min is not None:
                    qs = qs.filter(volume__gte=volume_min)
                if volume_max is not None:
                    qs = qs.filter(volume__lte=volume_max)
            if search_term:
                qs = qs.filter(keyword__icontains=search_term)
            return qs

        # Countries options
        countries = narrowed('country').values('country').annotate(
            keyword_count=Count('id')
        ).order_by('country')
        country_label_map = dict(SeedKeyword.COUNTRY_CHOICES)
        countries_data = [
            {
                'value': c['country'],
                'label': country_label_map.get(c['country'], c['country']),
                'keyword_count': c['keyword_count'],
            }
            for c in countries
            if c['country']  # rows with an empty country are not offered
        ]

        # Difficulty options
        difficulty_qs = narrowed('difficulty')
        difficulty_ranges = [
            (1, 'Very Easy', 0, 10),
            (2, 'Easy', 11, 30),
            (3, 'Medium', 31, 50),
            (4, 'Hard', 51, 70),
            (5, 'Very Hard', 71, 100),
        ]
        difficulty_levels = []
        for level, label, min_val, max_val in difficulty_ranges:
            count = difficulty_qs.filter(
                difficulty__gte=min_val,
                difficulty__lte=max_val
            ).count()
            if count > 0:  # only levels that would return results
                difficulty_levels.append({
                    'level': level,
                    'label': label,
                    'backend_range': [min_val, max_val],
                    'keyword_count': count,
                })

        difficulty_range = difficulty_qs.aggregate(
            min_difficulty=Min('difficulty'),
            max_difficulty=Max('difficulty')
        )

        # Volume range
        volume_range = narrowed('volume').aggregate(
            min_volume=Min('volume'),
            max_volume=Max('volume')
        )

        data = {
            'industries': industries_data,
            'sectors': sectors_data,
            'countries': countries_data,
            'difficulty': {
                'range': difficulty_range,
                'levels': difficulty_levels,
            },
            'volume': volume_range,
        }
        return success_response(data=data, request=request)

    except Exception as e:
        return error_response(
            error=f'Failed to fetch filter options: {str(e)}',
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            request=request
        )
|
|
|
|
@action(detail=False, methods=['post'], url_path='bulk_add', url_name='bulk_add')
def bulk_add(self, request):
    """
    Bulk add keywords to a site from the Keywords Library.

    Request body:
    - site_id: target site
    - keyword_ids: non-empty list of SeedKeyword IDs

    A seed keyword is skipped when the site already has it or when the site
    has no active, non-deleted sector matching the keyword's industry
    sector. Inactive or unknown IDs are neither added nor counted as skipped.

    Responses: 401 when unauthenticated (checked first, so anonymous callers
    cannot probe which site IDs exist), 400 for invalid input, 404 for an
    unknown site, 403 when the caller is neither staff nor in the site's
    account, 500 for unexpected failures.
    """
    from django.db import transaction
    from igny8_core.business.planning.models import Keywords

    try:
        user = request.user
        if not user.is_authenticated:
            return error_response(
                error='Authentication required',
                status_code=status.HTTP_401_UNAUTHORIZED,
                request=request
            )

        site_id = request.data.get('site_id')
        keyword_ids = request.data.get('keyword_ids', [])

        if not site_id:
            return error_response(
                error='site_id is required',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )

        if not keyword_ids or not isinstance(keyword_ids, list):
            return error_response(
                error='keyword_ids must be a non-empty list',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )

        # Reject non-numeric IDs as a client error instead of failing in the ORM.
        try:
            site_id = int(site_id)
            seed_ids = [int(raw_id) for raw_id in keyword_ids]
        except (ValueError, TypeError):
            return error_response(
                error='site_id and keyword_ids must be integer IDs',
                status_code=status.HTTP_400_BAD_REQUEST,
                request=request
            )

        # Verify the site exists
        from igny8_core.auth.models import Site
        site = Site.objects.filter(id=site_id).first()
        if not site:
            return error_response(
                error='Site not found',
                status_code=status.HTTP_404_NOT_FOUND,
                request=request
            )

        # Allow staff, or members of the account that owns the site. A user
        # without an account never matches, even if the site has none either.
        user_account_id = getattr(user, 'account_id', None)
        owns_site = user_account_id is not None and site.account_id == user_account_id
        if not (user.is_staff or owns_site):
            return error_response(
                error='Access denied to this site',
                status_code=status.HTTP_403_FORBIDDEN,
                request=request
            )

        # Active seed keywords among the requested IDs
        seed_keywords = SeedKeyword.objects.filter(
            id__in=seed_ids,
            is_active=True
        )

        # Seed keywords the site already has
        existing_seed_ids = set(
            Keywords.objects.filter(
                site_id=site_id,
                seed_keyword_id__in=seed_ids
            ).values_list('seed_keyword_id', flat=True)
        )

        # Site sectors keyed by industry_sector_id for fast lookup
        from igny8_core.auth.models import Sector
        site_sectors = {
            s.industry_sector_id: s
            for s in Sector.objects.filter(site=site, is_deleted=False, is_active=True)
        }

        added_count = 0
        skipped_count = 0

        with transaction.atomic():
            for seed_kw in seed_keywords:
                if seed_kw.id in existing_seed_ids:
                    skipped_count += 1
                    continue

                # The site must have a sector matching the keyword's industry sector
                site_sector = site_sectors.get(seed_kw.sector_id)
                if not site_sector:
                    skipped_count += 1
                    continue

                Keywords.objects.create(
                    site=site,
                    sector=site_sector,
                    seed_keyword=seed_kw,
                )
                added_count += 1

        return success_response(
            data={
                'added': added_count,
                'skipped': skipped_count,
                'total_requested': len(keyword_ids),
            },
            message=f'Successfully added {added_count} keywords to your site',
            request=request
        )

    except Exception as e:
        return error_response(
            error=f'Failed to bulk add keywords: {str(e)}',
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            request=request
        )
|
|
|
|
|
|
# ============================================================================
|
|
# AUTHENTICATION ENDPOINTS (Register, Login, Change Password, Me)
|
|
# ============================================================================
|
|
|
|
@extend_schema_view(
    register=extend_schema(tags=['Authentication']),
    login=extend_schema(tags=['Authentication']),
    change_password=extend_schema(tags=['Authentication']),
    # The token-refresh action is the method named ``refresh``; schema
    # overrides are matched by method name, so it must be listed as such.
    refresh=extend_schema(tags=['Authentication']),
    # NOTE(review): kept for backward compatibility - confirm whether a
    # ``refresh_token`` method exists on this class; if not, drop this entry.
    refresh_token=extend_schema(tags=['Authentication']),
)
class AuthViewSet(viewsets.GenericViewSet):
    """Authentication endpoints.
    Unified API Standard v1.0 compliant

    Open to any caller by default (AllowAny); individual actions override
    ``permission_classes`` where a logged-in user is required.
    """
    permission_classes = [permissions.AllowAny]
    throttle_scope = 'auth_strict'
    throttle_classes = [DebugScopedRateThrottle]
|
|
|
|
@action(detail=False, methods=['post'])
def register(self, request):
    """Register a new user, open a session for them and return JWT tokens.

    Invalid input yields a 400 response carrying the serializer errors;
    success yields 201 with the serialized user and a ``tokens`` object.
    """
    serializer = RegisterSerializer(data=request.data)
    if not serializer.is_valid():
        return error_response(
            error='Validation failed',
            errors=serializer.errors,
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )

    new_user = serializer.save()

    # Create a session too, so session authentication works right away.
    from django.contrib.auth import login as start_session
    start_session(request, new_user)

    account = getattr(new_user, 'account', None)

    # JWT pair plus the expiry timestamps of both tokens.
    tokens = {
        'access': generate_access_token(new_user, account),
        'refresh': generate_refresh_token(new_user, account),
        'access_expires_at': get_token_expiry('access').isoformat(),
        'refresh_expires_at': get_token_expiry('refresh').isoformat(),
    }

    return success_response(
        data={
            'user': UserSerializer(new_user).data,
            'tokens': tokens,
        },
        message='Registration successful',
        status_code=status.HTTP_201_CREATED,
        request=request
    )
|
|
|
|
@action(detail=False, methods=['post'])
def login(self, request):
    """Authenticate by email and password, open a session and return JWT tokens.

    Responses:
    - 400: payload failed serializer validation
    - 401: unknown email or wrong password (same message for both)
    - 403: password is correct but the user is deactivated, or the user
      has no account
    - 402: the account has no plan or the plan is not active
    - 200: serialized user plus access/refresh tokens and their expiries
    """
    serializer = LoginSerializer(data=request.data)
    if not serializer.is_valid():
        return error_response(
            error='Validation failed',
            errors=serializer.errors,
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )

    email = serializer.validated_data['email']
    password = serializer.validated_data['password']

    def invalid_credentials():
        """Single response for unknown email and wrong password alike."""
        return error_response(
            error='Invalid credentials',
            status_code=status.HTTP_401_UNAUTHORIZED,
            request=request
        )

    try:
        user = User.objects.select_related('account', 'account__plan').get(email=email)
    except User.DoesNotExist:
        return invalid_credentials()

    if not user.check_password(password):
        return invalid_credentials()

    # check_password() does not look at is_active, so deactivated users
    # must be rejected explicitly before any session or token is issued.
    if not getattr(user, 'is_active', True):
        return error_response(
            error='This account has been deactivated. Please contact support.',
            status_code=status.HTTP_403_FORBIDDEN,
            request=request,
        )

    # Ensure user has an account
    account = getattr(user, 'account', None)
    if account is None:
        return error_response(
            error='Account not configured for this user. Please contact support.',
            status_code=status.HTTP_403_FORBIDDEN,
            request=request,
        )

    # Ensure account has an active plan
    plan = getattr(account, 'plan', None)
    if plan is None or getattr(plan, 'is_active', False) is False:
        return error_response(
            error='Active subscription required. Visit igny8.com/pricing to subscribe.',
            status_code=status.HTTP_402_PAYMENT_REQUIRED,
            request=request,
        )

    # Log the user in (create session for session authentication).
    # Aliased because this method is itself named ``login``.
    from django.contrib.auth import login as start_session
    start_session(request, user)

    # Generate JWT tokens
    access_token = generate_access_token(user, account)
    refresh_token = generate_refresh_token(user, account)
    access_expires_at = get_token_expiry('access')
    refresh_expires_at = get_token_expiry('refresh')

    return success_response(
        data={
            'user': UserSerializer(user).data,
            'access': access_token,
            'refresh': refresh_token,
            'access_expires_at': access_expires_at.isoformat(),
            'refresh_expires_at': refresh_expires_at.isoformat(),
        },
        message='Login successful',
        request=request
    )
|
|
|
|
@action(detail=False, methods=['post'], permission_classes=[permissions.IsAuthenticated])
def change_password(self, request):
    """Change the authenticated user's password.

    Returns 400 when the payload is invalid or ``old_password`` does not
    match the current password. On success the password is saved and, if
    the request carries an existing session, its auth hash is refreshed so
    the caller is not logged out of the session by their own change.
    """
    from django.contrib.auth import update_session_auth_hash

    serializer = ChangePasswordSerializer(data=request.data, context={'request': request})
    if not serializer.is_valid():
        return error_response(
            error='Validation failed',
            errors=serializer.errors,
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )

    user = request.user
    if not user.check_password(serializer.validated_data['old_password']):
        return error_response(
            error='Current password is incorrect',
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )

    user.set_password(serializer.validated_data['new_password'])
    user.save()

    # Changing the password changes the session auth hash; without this the
    # current session would stop being valid. Token-only requests (no
    # session key) are left untouched.
    session = getattr(request, 'session', None)
    if session is not None and session.session_key:
        update_session_auth_hash(request, user)

    return success_response(
        message='Password changed successfully',
        request=request
    )
|
|
|
|
@action(detail=False, methods=['get'], permission_classes=[permissions.IsAuthenticated])
def me(self, request):
    """Return the authenticated user's profile under the ``user`` key.

    The user is re-read from the database (with account and plan joined)
    rather than taken from ``request.user``, so recent account/plan changes
    show up immediately.
    """
    fresh_user = (
        User.objects
        .select_related('account', 'account__plan')
        .get(id=request.user.id)
    )
    payload = {'user': UserSerializer(fresh_user).data}
    return success_response(data=payload, request=request)
|
|
|
|
@action(detail=False, methods=['post'], permission_classes=[permissions.AllowAny])
def refresh(self, request):
    """
    Exchange a refresh token for a new access token.

    Responses:
        200: ``access`` (new token) and ``access_expires_at`` (ISO 8601).
        400: the payload failed validation, or the decoded token's
             ``type`` claim is not ``refresh``.
        401: ``decode_token`` raised ``jwt.InvalidTokenError``.
        404: no user exists for the token's ``user_id`` claim.
    """
    serializer = RefreshTokenSerializer(data=request.data)
    if not serializer.is_valid():
        return error_response(
            error='Validation failed',
            errors=serializer.errors,
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )

    refresh_token = serializer.validated_data['refresh']

    # Only the decode step is guarded: a failure while *issuing* the new
    # token must not be reported to the client as a bad refresh token.
    try:
        payload = decode_token(refresh_token)
    except jwt.InvalidTokenError:
        return error_response(
            error='Invalid or expired refresh token',
            status_code=status.HTTP_401_UNAUTHORIZED,
            request=request
        )

    # An access token must not be usable to mint further access tokens.
    if payload.get('type') != 'refresh':
        return error_response(
            error='Invalid token type',
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )

    try:
        user = User.objects.get(id=payload.get('user_id'))
    except User.DoesNotExist:
        return error_response(
            error='User not found',
            status_code=status.HTTP_404_NOT_FOUND,
            request=request
        )

    # Prefer the account named in the token; if it is absent or no longer
    # exists, fall back to the user's own account (which may be None).
    account = None
    account_id = payload.get('account_id')
    if account_id:
        try:
            account = Account.objects.get(id=account_id)
        except Account.DoesNotExist:
            account = None
    if not account:
        account = getattr(user, 'account', None)

    access_token = generate_access_token(user, account)
    access_expires_at = get_token_expiry('access')

    return success_response(
        data={
            'access': access_token,
            'access_expires_at': access_expires_at.isoformat()
        },
        request=request
    )
|
|
|
|
@action(detail=False, methods=['post'], permission_classes=[permissions.AllowAny])
def request_reset(self, request):
    """
    Start the password-reset flow for the given email address.

    When a user with that email exists, a random URL-safe token valid for
    one hour is stored as a PasswordResetToken and emailed to the user:
    first through the project email service, and, if that raises, through
    Django's ``send_mail`` as a fallback.

    The success response is identical whether or not the email belongs to
    an account, and whether or not delivery succeeded, so the endpoint
    cannot be used to probe which addresses are registered. Delivery
    failures are logged instead of being surfaced to the caller.

    Responds with HTTP 400 only when the payload fails validation.
    """
    import logging
    import secrets
    from datetime import timedelta

    serializer = RequestPasswordResetSerializer(data=request.data)
    if not serializer.is_valid():
        return error_response(
            error='Validation failed',
            errors=serializer.errors,
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )

    email = serializer.validated_data['email']
    generic_message = 'If an account with that email exists, a password reset link has been sent.'

    try:
        user = User.objects.get(email=email)
    except User.DoesNotExist:
        # Don't reveal if email exists - return success anyway
        return success_response(message=generic_message, request=request)

    # Generate secure token and persist it with a one-hour lifetime.
    token = secrets.token_urlsafe(32)
    expires_at = timezone.now() + timedelta(hours=1)
    PasswordResetToken.objects.create(
        user=user,
        token=token,
        expires_at=expires_at
    )

    # Send password reset email using the email service
    try:
        from igny8_core.business.billing.services.email_service import send_password_reset_email
        send_password_reset_email(user, token)
    except Exception as e:
        # Fallback to Django's send_mail if email service fails
        logger = logging.getLogger(__name__)
        logger.error("Failed to send password reset email via email service: %s", e)

        from django.core.mail import send_mail
        from django.conf import settings

        frontend_url = getattr(settings, 'FRONTEND_URL', 'https://app.igny8.com')
        reset_url = f"{frontend_url}/reset-password?token={token}"

        try:
            send_mail(
                subject='Reset Your IGNY8 Password',
                message=f'Click the following link to reset your password: {reset_url}\n\nThis link expires in 1 hour.',
                from_email=getattr(settings, 'DEFAULT_FROM_EMAIL', 'noreply@igny8.com'),
                recipient_list=[user.email],
                fail_silently=False,
            )
        except Exception:
            # An unhandled error here would turn into a 500 that only
            # occurs for existing accounts, leaking their existence.
            # Record it for operators and keep the response uniform.
            logger.exception("Fallback password reset email delivery also failed")

    return success_response(message=generic_message, request=request)
|
|
|
|
@action(detail=False, methods=['post'], permission_classes=[permissions.AllowAny])
def reset_password(self, request):
    """
    Set a new password using a previously issued reset token.

    Responds with HTTP 400 when the payload fails validation, when no
    PasswordResetToken matches ``token``, or when the token's
    ``is_valid()`` check fails. On success the user's password is
    replaced and the token is flagged as used.
    """
    serializer = ResetPasswordSerializer(data=request.data)
    if not serializer.is_valid():
        return error_response(
            error='Validation failed',
            errors=serializer.errors,
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )

    token = serializer.validated_data['token']
    new_password = serializer.validated_data['new_password']

    try:
        reset_token = PasswordResetToken.objects.get(token=token)
    except PasswordResetToken.DoesNotExist:
        return error_response(
            error='Invalid reset token',
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )

    # Check if token is valid
    if not reset_token.is_valid():
        return error_response(
            error='Reset token has expired or has already been used',
            status_code=status.HTTP_400_BAD_REQUEST,
            request=request
        )

    user = reset_token.user

    # Both writes succeed or fail together: if marking the token as used
    # failed after the password had been saved, the token would remain
    # reusable even though the reset had already taken effect.
    with transaction.atomic():
        user.set_password(new_password)
        user.save()

        reset_token.used = True
        reset_token.save()

    return success_response(
        message='Password has been reset successfully',
        request=request
    )
|
|
|
|
|
|
# ============================================================================
|
|
# CSV Import/Export Views for Admin
|
|
# ============================================================================
|
|
|
|
from django.http import HttpResponse, JsonResponse
|
|
from django.contrib.admin.views.decorators import staff_member_required
|
|
from django.views.decorators.http import require_http_methods
|
|
import csv
|
|
import io
|
|
|
|
|
|
@staff_member_required
@require_http_methods(["GET"])
def industry_csv_template(request):
    """Serve a sample CSV (header plus two example rows) for the Industry import."""
    header = ['name', 'description', 'is_active']
    examples = [
        ['Technology', 'Technology industry', 'true'],
        ['Healthcare', 'Healthcare and medical services', 'true'],
    ]

    # The Content-Disposition header makes the browser download the file.
    download = HttpResponse(content_type='text/csv')
    download['Content-Disposition'] = 'attachment; filename="industry_template.csv"'

    sheet = csv.writer(download)
    sheet.writerow(header)
    sheet.writerows(examples)
    return download
|
|
|
|
|
|
@staff_member_required
@require_http_methods(["POST"])
def industry_csv_import(request):
    """
    Create or update Industry records from an uploaded CSV file.

    The file is read from ``request.FILES['csv_file']`` and must be UTF-8
    (a leading byte-order mark is tolerated). Each data row is matched on
    ``name``; ``slug`` is derived from the name, and ``description`` and
    ``is_active`` are written from the row. ``is_active`` is true for the
    values ``true``, ``1`` or ``yes`` (case-insensitive) and defaults to
    true when the cell is missing.

    Returns a JSON body with ``created`` and ``updated`` counts and a list
    of per-row ``errors``. Row-level failures do not abort the import, and
    ``success`` is true even when some rows were rejected. HTTP 400 is
    returned when no file was uploaded or it cannot be decoded.
    """
    from django.utils.text import slugify

    csv_file = request.FILES.get('csv_file')
    if not csv_file:
        return JsonResponse({'success': False, 'error': 'No CSV file provided'}, status=400)

    # 'utf-8-sig' strips the BOM that spreadsheet exports often prepend;
    # with plain 'utf-8' the first header would become '\ufeffname' and
    # every row would fail the 'name' lookup.
    try:
        decoded_file = csv_file.read().decode('utf-8-sig')
    except UnicodeDecodeError:
        return JsonResponse(
            {'success': False, 'error': 'CSV file must be UTF-8 encoded'},
            status=400
        )

    reader = csv.DictReader(io.StringIO(decoded_file))

    created = 0
    updated = 0
    errors = []

    # Numbering starts at 2 because the header occupies the first row.
    for row_num, row in enumerate(reader, start=2):
        try:
            name = row.get('name')
            if not name or not name.strip():
                errors.append(f"Row {row_num}: 'name' is required")
                continue

            # DictReader yields None for cells missing from a short row,
            # so optional columns are normalized before use.
            raw_active = row.get('is_active')
            if raw_active is None:
                raw_active = 'true'
            is_active = raw_active.strip().lower() in ('true', '1', 'yes')

            _, created_flag = Industry.objects.update_or_create(
                name=name,
                defaults={
                    'slug': slugify(name),
                    'description': row.get('description') or '',
                    'is_active': is_active
                }
            )
            if created_flag:
                created += 1
            else:
                updated += 1
        except Exception as e:
            # Per-row boundary: record the failure and keep importing.
            errors.append(f"Row {row_num}: {e}")

    return JsonResponse({
        'success': True,
        'created': created,
        'updated': updated,
        'errors': errors
    })
|
|
|
|
|
|
@staff_member_required
@require_http_methods(["GET"])
def industrysector_csv_template(request):
    """Serve a sample CSV (header plus two example rows) for the IndustrySector import."""
    header = ['name', 'industry', 'description', 'is_active']
    examples = [
        ['Software Development', 'Technology', 'Software and app development', 'true'],
        ['Healthcare IT', 'Healthcare', 'Healthcare information technology', 'true'],
    ]

    # The Content-Disposition header makes the browser download the file.
    download = HttpResponse(content_type='text/csv')
    download['Content-Disposition'] = 'attachment; filename="industrysector_template.csv"'

    sheet = csv.writer(download)
    sheet.writerow(header)
    sheet.writerows(examples)
    return download
|
|
|
|
|
|
@staff_member_required
@require_http_methods(["POST"])
def industrysector_csv_import(request):
    """
    Create or update IndustrySector records from an uploaded CSV file.

    The file is read from ``request.FILES['csv_file']`` and must be UTF-8
    (a leading byte-order mark is tolerated). For each data row the
    ``industry`` column is resolved to an existing Industry by name; rows
    naming an unknown industry are reported and skipped. Sectors are
    matched on (``name``, industry); ``slug`` is derived from the name,
    and ``description`` and ``is_active`` are written from the row.
    ``is_active`` is true for ``true``, ``1`` or ``yes``
    (case-insensitive) and defaults to true when the cell is missing.

    Returns a JSON body with ``created`` and ``updated`` counts and a list
    of per-row ``errors``. Row-level failures do not abort the import, and
    ``success`` is true even when some rows were rejected. HTTP 400 is
    returned when no file was uploaded or it cannot be decoded.
    """
    from django.utils.text import slugify

    csv_file = request.FILES.get('csv_file')
    if not csv_file:
        return JsonResponse({'success': False, 'error': 'No CSV file provided'}, status=400)

    # 'utf-8-sig' strips the BOM that spreadsheet exports often prepend;
    # with plain 'utf-8' the first header would become '\ufeffname' and
    # every row would fail the 'name' lookup.
    try:
        decoded_file = csv_file.read().decode('utf-8-sig')
    except UnicodeDecodeError:
        return JsonResponse(
            {'success': False, 'error': 'CSV file must be UTF-8 encoded'},
            status=400
        )

    reader = csv.DictReader(io.StringIO(decoded_file))

    created = 0
    updated = 0
    errors = []

    # Numbering starts at 2 because the header occupies the first row.
    for row_num, row in enumerate(reader, start=2):
        try:
            name = row.get('name')
            if not name or not name.strip():
                errors.append(f"Row {row_num}: 'name' is required")
                continue

            # DictReader yields None for cells missing from a short row,
            # so optional columns are normalized before use.
            raw_active = row.get('is_active')
            if raw_active is None:
                raw_active = 'true'
            is_active = raw_active.strip().lower() in ('true', '1', 'yes')

            # Find industry by name
            try:
                industry = Industry.objects.get(name=row['industry'])
            except Industry.DoesNotExist:
                errors.append(f"Row {row_num}: Industry '{row['industry']}' not found")
                continue

            _, created_flag = IndustrySector.objects.update_or_create(
                name=name,
                industry=industry,
                defaults={
                    'slug': slugify(name),
                    'description': row.get('description') or '',
                    'is_active': is_active
                }
            )
            if created_flag:
                created += 1
            else:
                updated += 1
        except Exception as e:
            # Per-row boundary: record the failure and keep importing.
            errors.append(f"Row {row_num}: {e}")

    return JsonResponse({
        'success': True,
        'created': created,
        'updated': updated,
        'errors': errors
    })
|
|
|
|
|
|
@staff_member_required
@require_http_methods(["GET"])
def seedkeyword_csv_template(request):
    """Serve a sample CSV (header plus two example rows) for the SeedKeyword import."""
    header = ['keyword', 'industry', 'sector', 'volume', 'difficulty', 'country', 'is_active']
    examples = [
        ['python programming', 'Technology', 'Software Development', '10000', '45', 'US', 'true'],
        ['medical software', 'Healthcare', 'Healthcare IT', '5000', '60', 'CA', 'true'],
    ]

    # The Content-Disposition header makes the browser download the file.
    download = HttpResponse(content_type='text/csv')
    download['Content-Disposition'] = 'attachment; filename="seedkeyword_template.csv"'

    sheet = csv.writer(download)
    sheet.writerow(header)
    sheet.writerows(examples)
    return download
|
|
|
|
|
|
@staff_member_required
@require_http_methods(["POST"])
def seedkeyword_csv_import(request):
    """
    Create or update SeedKeyword records from an uploaded CSV file.

    The file is read from ``request.FILES['csv_file']`` and must be UTF-8
    (a leading byte-order mark is tolerated). For each data row the
    ``industry`` column is resolved to an existing Industry by name and
    the ``sector`` column to an IndustrySector within that industry; rows
    referencing an unknown industry or sector are reported and skipped.
    Keywords are matched on (``keyword``, industry, sector), and
    ``volume``, ``difficulty``, ``country`` and ``is_active`` are written
    from the row. Blank or missing ``volume``/``difficulty`` cells fall
    back to 0 and a blank or missing ``country`` falls back to ``US``.
    ``is_active`` is true for ``true``, ``1`` or ``yes``
    (case-insensitive) and defaults to true when the cell is missing.

    Returns a JSON body with ``created`` and ``updated`` counts and a list
    of per-row ``errors``. Row-level failures do not abort the import, and
    ``success`` is true even when some rows were rejected. HTTP 400 is
    returned when no file was uploaded or it cannot be decoded.
    """
    csv_file = request.FILES.get('csv_file')
    if not csv_file:
        return JsonResponse({'success': False, 'error': 'No CSV file provided'}, status=400)

    # 'utf-8-sig' strips the BOM that spreadsheet exports often prepend;
    # with plain 'utf-8' the first header would become '\ufeffkeyword'
    # and every row would fail the 'keyword' lookup.
    try:
        decoded_file = csv_file.read().decode('utf-8-sig')
    except UnicodeDecodeError:
        return JsonResponse(
            {'success': False, 'error': 'CSV file must be UTF-8 encoded'},
            status=400
        )

    reader = csv.DictReader(io.StringIO(decoded_file))

    created = 0
    updated = 0
    errors = []

    # Numbering starts at 2 because the header occupies the first row.
    for row_num, row in enumerate(reader, start=2):
        try:
            keyword_text = row.get('keyword')
            if not keyword_text or not keyword_text.strip():
                errors.append(f"Row {row_num}: 'keyword' is required")
                continue

            # DictReader yields None for cells missing from a short row,
            # so optional columns are normalized before use.
            raw_active = row.get('is_active')
            if raw_active is None:
                raw_active = 'true'
            is_active = raw_active.strip().lower() in ('true', '1', 'yes')

            # Find industry and sector by name
            try:
                industry = Industry.objects.get(name=row['industry'])
            except Industry.DoesNotExist:
                errors.append(f"Row {row_num}: Industry '{row['industry']}' not found")
                continue

            try:
                sector = IndustrySector.objects.get(name=row['sector'], industry=industry)
            except IndustrySector.DoesNotExist:
                errors.append(f"Row {row_num}: Sector '{row['sector']}' not found in industry '{row['industry']}'")
                continue

            _, created_flag = SeedKeyword.objects.update_or_create(
                keyword=keyword_text,
                industry=industry,
                sector=sector,
                defaults={
                    # Non-numeric text still raises ValueError and is
                    # reported as a row error by the handler below.
                    'volume': int(row.get('volume') or 0),
                    'difficulty': int(row.get('difficulty') or 0),
                    'country': row.get('country') or 'US',
                    'is_active': is_active
                }
            )
            if created_flag:
                created += 1
            else:
                updated += 1
        except Exception as e:
            # Per-row boundary: record the failure and keep importing.
            errors.append(f"Row {row_num}: {e}")

    return JsonResponse({
        'success': True,
        'created': created,
        'updated': updated,
        'errors': errors
    })
|
|
|