style: implement custom color schemes and gradients for account section, enhancing visual hierarchy
537 lines
23 KiB
Python
537 lines
23 KiB
Python
"""
Authentication Serializers
"""
|
|
from rest_framework import serializers
|
|
from django.contrib.auth.password_validation import validate_password
|
|
from .models import User, Account, Plan, Subscription, Site, Sector, SiteUserAccess, Industry, IndustrySector, SeedKeyword
|
|
|
|
|
|
class PlanSerializer(serializers.ModelSerializer):
    """Serializer for Plan model."""

    class Meta:
        model = Plan
        fields = [
            # Identity, pricing and presentation flags
            'id',
            'name',
            'slug',
            'price',
            'billing_cycle',
            'annual_discount_percent',
            'is_featured',
            'features',
            'is_active',
            # max_* limits
            'max_users',
            'max_sites',
            'max_industries',
            'max_author_profiles',
            'max_keywords',
            'max_clusters',
            'max_content_ideas',
            'max_content_words',
            'max_images_basic',
            'max_images_premium',
            'max_image_prompts',
            # Credit settings
            'included_credits',
            'extra_credit_price',
            'allow_credit_topup',
            'auto_credit_topup_threshold',
            'auto_credit_topup_amount',
            # Stripe identifiers and monthly credits
            'stripe_product_id',
            'stripe_price_id',
            'credits_per_month',
        ]
|
|
|
|
|
|
class SubscriptionSerializer(serializers.ModelSerializer):
    """Serializer for Subscription model.

    Adds two read-only convenience fields, ``account_name`` and
    ``account_slug``, read from the related account.
    """

    # Flattened, read-only views of the related account.
    account_name = serializers.CharField(
        source='account.name',
        read_only=True,
    )
    account_slug = serializers.CharField(
        source='account.slug',
        read_only=True,
    )

    class Meta:
        model = Subscription
        fields = [
            'id',
            'account',
            'account_name',
            'account_slug',
            'stripe_subscription_id',
            'payment_method',
            'external_payment_id',
            'status',
            'current_period_start',
            'current_period_end',
            'cancel_at_period_end',
            'created_at',
            'updated_at',
        ]
        read_only_fields = ['created_at', 'updated_at']
|
|
|
|
|
|
class AccountSerializer(serializers.ModelSerializer):
    """Serializer for Account model.

    ``plan`` is rendered as a nested, read-only ``PlanSerializer``. Writes go
    through ``plan_id``, which accepts the primary key of an active plan and
    maps onto the same ``plan`` attribute (``source='plan'``).
    """

    plan = PlanSerializer(read_only=True)
    plan_id = serializers.PrimaryKeyRelatedField(
        queryset=Plan.objects.filter(is_active=True),
        write_only=True,
        source='plan',
        required=False,
    )
    subscription = SubscriptionSerializer(read_only=True, allow_null=True)

    def validate_plan_id(self, value):
        """Reject an empty plan on creation when ``plan_id`` is present.

        DRF only calls this hook when ``plan_id`` is part of the payload, so
        it cannot detect a missing key; that case is handled in ``validate``.
        """
        if self.instance is None and not value:
            raise serializers.ValidationError("plan_id is required when creating an account.")
        return value

    def validate(self, attrs):
        """Require a plan when creating an account.

        ``plan_id`` is declared ``required=False`` so that updates may omit
        it. The creation-time requirement is enforced here because
        field-level validators are skipped for omitted optional fields.

        Raises:
            serializers.ValidationError: keyed on ``plan_id`` when the
                serializer has no instance and no plan was supplied.
        """
        # ``source='plan'`` means the validated value is stored under 'plan'.
        if self.instance is None and attrs.get('plan') is None:
            raise serializers.ValidationError({
                "plan_id": "plan_id is required when creating an account."
            })
        return attrs

    class Meta:
        model = Account
        fields = [
            'id', 'name', 'slug', 'owner', 'plan', 'plan_id',
            'credits', 'status', 'payment_method',
            'subscription', 'created_at'
        ]
        read_only_fields = ['owner', 'created_at']
|
|
|
|
|
|
class SiteSerializer(serializers.ModelSerializer):
    """Serializer for Site model.

    Besides the model fields it exposes sector statistics computed per
    object, the related industry's name/slug, and a ``domain`` field that is
    normalized to an ``https://`` URL by ``validate_domain``.
    """

    sectors_count = serializers.SerializerMethodField()
    active_sectors_count = serializers.SerializerMethodField()
    selected_sectors = serializers.SerializerMethodField()
    can_add_sectors = serializers.SerializerMethodField()
    industry_name = serializers.CharField(source='industry.name', read_only=True)
    industry_slug = serializers.CharField(source='industry.slug', read_only=True)
    # Override domain field to use CharField instead of URLField to avoid premature validation
    domain = serializers.CharField(required=False, allow_blank=True, allow_null=True)

    class Meta:
        model = Site
        fields = [
            'id', 'name', 'slug', 'domain', 'description',
            'industry', 'industry_name', 'industry_slug',
            'is_active', 'status',
            'site_type', 'hosting_type', 'seo_metadata',
            'sectors_count', 'active_sectors_count', 'selected_sectors',
            'can_add_sectors',
            'created_at', 'updated_at'
        ]
        # NOTE(review): 'account' is not listed in `fields` above - confirm
        # whether it is intentionally kept here.
        read_only_fields = ['created_at', 'updated_at', 'account']
        # Explicitly specify required fields for clarity
        extra_kwargs = {
            'industry': {'required': True, 'error_messages': {'required': 'Industry is required when creating a site.'}},
        }

    def __init__(self, *args, **kwargs):
        """Relax ``required`` flags after the fields have been built.

        ``slug`` is always optional (``validate`` derives it from ``name``).
        For partial serializers (PATCH), ``name`` and ``industry`` are made
        optional as well.
        """
        super().__init__(*args, **kwargs)

        # Make slug optional - it will be auto-generated from name if not provided
        if 'slug' in self.fields:
            self.fields['slug'].required = False

        # For partial updates (PATCH), make name and industry optional
        if self.partial:
            for field_name in ('name', 'industry'):
                if field_name in self.fields:
                    self.fields[field_name].required = False

    def validate_domain(self, value):
        """Normalize ``domain`` to an ``https://`` URL and validate it.

        - Empty, whitespace-only or ``None`` input returns ``None``.
        - A leading ``https://`` (any letter case) is kept, lower-cased.
        - A leading ``http://`` (any letter case) is replaced with ``https://``.
        - Anything else gets ``https://`` prepended.

        Returns:
            The normalized URL string, or ``None`` for empty input.

        Raises:
            serializers.ValidationError: if the normalized value is rejected
                by Django's ``URLValidator``.
        """
        from django.core.exceptions import ValidationError
        from django.core.validators import URLValidator

        # Allow empty/None values
        if not value or not value.strip():
            return None

        value = value.strip()

        # Compare the scheme case-insensitively so that e.g. "HTTP://x.com"
        # is upgraded instead of being turned into "https://HTTP://x.com".
        if value[:8].lower() == 'https://':
            normalized = 'https://' + value[8:]
        elif value[:7].lower() == 'http://':
            normalized = 'https://' + value[7:]
        else:
            normalized = f'https://{value}'

        # Validate that the normalized URL is a valid URL format
        try:
            URLValidator()(normalized)
        except ValidationError:
            # The Django error adds nothing for API clients; drop the chain.
            raise serializers.ValidationError("Enter a valid URL or domain name.") from None

        return normalized

    def validate(self, attrs):
        """Auto-generate slug from name if not provided.

        When ``slug`` is missing or empty and a non-empty ``name`` is present
        in ``attrs``, ``slug`` is set to ``slugify(name)``.
        """
        if not attrs.get('slug') and attrs.get('name'):
            from django.utils.text import slugify
            attrs['slug'] = slugify(attrs['name'])
        return attrs

    def get_sectors_count(self, obj):
        """Get total sectors count."""
        return obj.sectors.count()

    def get_active_sectors_count(self, obj):
        """Get active sectors count."""
        return obj.sectors.filter(is_active=True).count()

    def get_selected_sectors(self, obj):
        """Get list of IDs of the site's active sectors."""
        return list(obj.sectors.filter(is_active=True).values_list('id', flat=True))

    def get_can_add_sectors(self, obj):
        """Return the result of ``obj.can_add_sector()``.

        NOTE(review): the original note mentions a limit of 5 sectors; the
        limit itself lives in the model method - confirm there.
        """
        return obj.can_add_sector()
|
|
|
|
|
|
class IndustrySectorSerializer(serializers.ModelSerializer):
    """Serializer for IndustrySector model.

    ``id``, ``industry`` and both timestamps are read-only.
    """

    class Meta:
        model = IndustrySector
        fields = [
            'id',
            'industry',
            'name',
            'slug',
            'description',
            'is_active',
            'created_at',
            'updated_at',
        ]
        read_only_fields = ['created_at', 'updated_at', 'id', 'industry']
|
|
|
|
|
|
class IndustrySerializer(serializers.ModelSerializer):
    """Serializer for Industry model.

    Embeds the industry's sectors (read-only) and a count of the active ones.
    """

    # Nested, read-only list of the industry's sectors.
    sectors = IndustrySectorSerializer(many=True, read_only=True)
    # Computed by get_sectors_count().
    sectors_count = serializers.SerializerMethodField()

    class Meta:
        model = Industry
        fields = [
            'id',
            'name',
            'slug',
            'description',
            'is_active',
            'sectors',
            'sectors_count',
            'created_at',
            'updated_at',
        ]
        read_only_fields = ['created_at', 'updated_at']

    def get_sectors_count(self, obj):
        """Return how many of the industry's sectors have ``is_active=True``."""
        active_sectors = obj.sectors.filter(is_active=True)
        return active_sectors.count()
|
|
|
|
|
|
class SectorSerializer(serializers.ModelSerializer):
    """Serializer for Sector model.

    Adds read-only names/slugs of the related site, industry sector and
    industry, plus per-sector keyword and cluster counts.
    """

    site_name = serializers.CharField(source='site.name', read_only=True)
    industry_sector_name = serializers.CharField(source='industry_sector.name', read_only=True)
    industry_sector_slug = serializers.CharField(source='industry_sector.slug', read_only=True)
    industry_name = serializers.SerializerMethodField()
    industry_slug = serializers.SerializerMethodField()
    keywords_count = serializers.SerializerMethodField()
    clusters_count = serializers.SerializerMethodField()

    class Meta:
        model = Sector
        fields = [
            'id', 'site', 'site_name', 'industry_sector', 'industry_sector_name',
            'industry_sector_slug', 'industry_name', 'industry_slug',
            'name', 'slug', 'description',
            'is_active', 'status', 'keywords_count', 'clusters_count',
            'created_at', 'updated_at'
        ]
        # NOTE(review): 'account' is not listed in `fields` above - confirm
        # whether it is intentionally kept here.
        read_only_fields = ['created_at', 'updated_at', 'account']

    def get_industry_name(self, obj):
        """Get industry name via ``industry_sector``; ``None`` when it is unset."""
        return obj.industry_sector.industry.name if obj.industry_sector else None

    def get_industry_slug(self, obj):
        """Get industry slug via ``industry_sector``; ``None`` when it is unset."""
        return obj.industry_sector.industry.slug if obj.industry_sector else None

    def get_keywords_count(self, obj):
        """Get keywords count in this sector.

        Uses the ``keywords_set`` reverse relation. The previous
        ``getattr(obj, 'keywords_set', obj.keywords_set)`` evaluated its
        default eagerly, so it was equivalent to this direct access.
        """
        return obj.keywords_set.count()

    def get_clusters_count(self, obj):
        """Get clusters count in this sector.

        Uses the ``clusters_set`` reverse relation. The previous
        ``getattr(obj, 'clusters_set', obj.clusters_set)`` evaluated its
        default eagerly, so it was equivalent to this direct access.
        """
        return obj.clusters_set.count()
|
|
|
|
|
|
class SiteUserAccessSerializer(serializers.ModelSerializer):
    """Serializer for SiteUserAccess model.

    Exposes the related user's email/username and the related site's name as
    read-only fields alongside the raw ``user`` and ``site`` references.
    """

    user_email = serializers.CharField(
        source='user.email',
        read_only=True,
    )
    user_name = serializers.CharField(
        source='user.username',
        read_only=True,
    )
    site_name = serializers.CharField(
        source='site.name',
        read_only=True,
    )

    class Meta:
        model = SiteUserAccess
        fields = [
            'id',
            'user',
            'user_email',
            'user_name',
            'site',
            'site_name',
            'granted_at',
            'granted_by',
        ]
        read_only_fields = ['granted_at']
|
|
|
|
|
|
from igny8_core.business.billing.models import PAYMENT_METHOD_CHOICES
|
|
|
|
|
|
class UserSerializer(serializers.ModelSerializer):
    """Serializer for User model.

    Embeds the user's account (read-only) and the sites returned by
    ``User.get_accessible_sites()``.
    """

    # Nested, read-only account representation.
    account = AccountSerializer(read_only=True)
    # Computed by get_accessible_sites().
    accessible_sites = serializers.SerializerMethodField()

    class Meta:
        model = User
        fields = [
            'id',
            'username',
            'email',
            'role',
            'account',
            'accessible_sites',
            'created_at',
        ]
        read_only_fields = ['created_at']

    def get_accessible_sites(self, obj):
        """Serialize ``obj.get_accessible_sites()`` with ``SiteSerializer``."""
        site_serializer = SiteSerializer(obj.get_accessible_sites(), many=True)
        return site_serializer.data
|
|
|
|
|
|
class RegisterSerializer(serializers.Serializer):
    """Serializer for user registration.

    ``create()`` builds, inside one database transaction, a ``User`` with
    role ``owner`` and an ``Account`` owned by that user. Registrations whose
    ``plan_slug`` is one of ``PAID_PLAN_SLUGS`` start in ``pending_payment``
    with zero credits and also get a subscription, a first invoice and a
    default payment method. All other registrations use the ``free`` plan,
    start in ``trial`` and are seeded with that plan's monthly credits.
    """

    # Plan slugs treated as paid. Shared by validate() and create() so the
    # two code paths cannot drift apart.
    PAID_PLAN_SLUGS = ('starter', 'growth', 'scale')

    email = serializers.EmailField()
    username = serializers.CharField(max_length=150, required=False)
    password = serializers.CharField(write_only=True, validators=[validate_password])
    password_confirm = serializers.CharField(write_only=True)
    first_name = serializers.CharField(max_length=150, required=False, allow_blank=True)
    last_name = serializers.CharField(max_length=150, required=False, allow_blank=True)
    account_name = serializers.CharField(max_length=255, required=False, allow_blank=True, allow_null=True, default=None)
    # NOTE(review): plan_id is accepted and validated but create() never
    # reads it; the plan is chosen from plan_slug only - confirm intent.
    plan_id = serializers.PrimaryKeyRelatedField(
        queryset=Plan.objects.filter(is_active=True),
        required=False,
        allow_null=True,
        default=None
    )
    plan_slug = serializers.CharField(max_length=50, required=False)
    payment_method = serializers.ChoiceField(
        choices=[choice[0] for choice in PAYMENT_METHOD_CHOICES],
        default='bank_transfer',
        required=False
    )
    # Billing information fields
    billing_email = serializers.EmailField(required=False, allow_blank=True)
    billing_address_line1 = serializers.CharField(max_length=255, required=False, allow_blank=True)
    billing_address_line2 = serializers.CharField(max_length=255, required=False, allow_blank=True)
    billing_city = serializers.CharField(max_length=100, required=False, allow_blank=True)
    billing_state = serializers.CharField(max_length=100, required=False, allow_blank=True)
    billing_postal_code = serializers.CharField(max_length=20, required=False, allow_blank=True)
    billing_country = serializers.CharField(max_length=2, required=False, allow_blank=True)
    tax_id = serializers.CharField(max_length=100, required=False, allow_blank=True)

    @classmethod
    def _is_paid_plan(cls, plan_slug):
        """Return True when ``plan_slug`` is one of ``PAID_PLAN_SLUGS``."""
        return bool(plan_slug) and plan_slug in cls.PAID_PLAN_SLUGS

    @staticmethod
    def _get_active_plan(slug, error_message):
        """Fetch the active plan with ``slug`` or raise a validation error.

        Raises:
            serializers.ValidationError: keyed on ``plan`` with
                ``error_message`` when no active plan has that slug.
        """
        try:
            return Plan.objects.get(slug=slug, is_active=True)
        except Plan.DoesNotExist:
            raise serializers.ValidationError({"plan": error_message}) from None

    @staticmethod
    def _default_account_name(validated_data):
        """Derive an account name from the user's name, else the email local part."""
        first_name = validated_data.get('first_name', '')
        last_name = validated_data.get('last_name', '')
        if first_name or last_name:
            full_name = f"{first_name} {last_name}".strip()
            if full_name:
                return full_name
        return validated_data['email'].split('@')[0]

    @staticmethod
    def _unique_username(base_username):
        """Return ``base_username``, suffixed with 1, 2, ... until unused."""
        username = base_username
        counter = 1
        while User.objects.filter(username=username).exists():
            username = f"{base_username}{counter}"
            counter += 1
        return username

    @staticmethod
    def _unique_account_slug(account_name):
        """Return an account slug derived from ``account_name`` that is unused.

        NOTE(review): only spaces and underscores are replaced; any other
        character (e.g. '.') passes through unchanged - confirm that
        Account.slug accepts such values.
        """
        base_slug = account_name.lower().replace(' ', '-').replace('_', '-')[:50] or 'account'
        slug = base_slug
        counter = 1
        while Account.objects.filter(slug=slug).exists():
            slug = f"{base_slug}-{counter}"
            counter += 1
        return slug

    @staticmethod
    def _create_paid_plan_records(account, plan, payment_method, period_start, period_end):
        """Create the subscription, first invoice and default payment method."""
        from igny8_core.auth.models import Subscription
        from igny8_core.business.billing.models import AccountPaymentMethod
        from igny8_core.business.billing.services.invoice_service import InvoiceService

        subscription = Subscription.objects.create(
            account=account,
            plan=plan,
            status='pending_payment',
            external_payment_id=None,
            current_period_start=period_start,
            current_period_end=period_end,
            cancel_at_period_end=False,
        )
        # Create pending invoice for the first period
        InvoiceService.create_subscription_invoice(
            subscription=subscription,
            billing_period_start=period_start,
            billing_period_end=period_end,
        )
        # Create AccountPaymentMethod with selected payment method
        payment_method_display_names = {
            'stripe': 'Credit/Debit Card (Stripe)',
            'paypal': 'PayPal',
            'bank_transfer': 'Bank Transfer (Manual)',
            'local_wallet': 'Mobile Wallet (Manual)',
        }
        AccountPaymentMethod.objects.create(
            account=account,
            type=payment_method,
            display_name=payment_method_display_names.get(payment_method, payment_method.title()),
            is_default=True,
            is_enabled=True,
            is_verified=False,
            instructions='Please complete payment and confirm with your transaction reference.',
        )

    def validate(self, attrs):
        """Cross-field validation for registration.

        Checks that both passwords match, converts empty ``account_name`` /
        ``plan_id`` values to ``None``, and for paid plans requires
        ``billing_country`` and ``payment_method``.

        Raises:
            serializers.ValidationError: keyed on the offending field.
        """
        if attrs['password'] != attrs['password_confirm']:
            raise serializers.ValidationError({"password": "Passwords do not match"})

        # Convert empty strings to None for optional fields
        for optional_field in ('account_name', 'plan_id'):
            if attrs.get(optional_field) == '':
                attrs[optional_field] = None

        # Validate billing fields for paid plans
        if self._is_paid_plan(attrs.get('plan_slug')):
            # Require billing_country for paid plans
            if not attrs.get('billing_country'):
                raise serializers.ValidationError({
                    "billing_country": "Billing country is required for paid plans."
                })
            # Require payment_method for paid plans
            if not attrs.get('payment_method'):
                raise serializers.ValidationError({
                    "payment_method": "Payment method is required for paid plans."
                })

        return attrs

    def create(self, validated_data):
        """Create the user and their account atomically.

        Returns:
            The newly created ``User`` with ``account`` set.

        Raises:
            serializers.ValidationError: when the requested paid plan, or the
                ``free`` plan, does not exist as an active plan.
        """
        from datetime import timedelta

        from django.db import transaction
        from django.utils import timezone

        from igny8_core.business.billing.models import CreditTransaction

        with transaction.atomic():
            plan_slug = validated_data.get('plan_slug')
            is_paid = self._is_paid_plan(plan_slug)

            if is_paid:
                plan = self._get_active_plan(
                    plan_slug,
                    f"Plan '{plan_slug}' not available. Please contact support.",
                )
                account_status = 'pending_payment'
                initial_credits = 0
                billing_period_start = timezone.now()
                # simple monthly cycle; if annual needed, extend here
                billing_period_end = billing_period_start + timedelta(days=30)
            else:
                plan = self._get_active_plan(
                    'free',
                    "Free plan not configured. Please contact support.",
                )
                account_status = 'trial'
                initial_credits = plan.get_effective_credits_per_month()
                billing_period_start = None
                billing_period_end = None

            # Generate account name if not provided
            account_name = validated_data.get('account_name')
            if not account_name:
                account_name = self._default_account_name(validated_data)

            # Generate username if not provided, then ensure it is unique
            username = validated_data.get('username')
            if not username:
                username = validated_data['email'].split('@')[0]
            username = self._unique_username(username)

            # One fallback for both the account and the payment-method record.
            payment_method = validated_data.get('payment_method') or 'bank_transfer'

            # Create user first without account (User.account is nullable)
            user = User.objects.create_user(
                username=username,
                email=validated_data['email'],
                password=validated_data['password'],
                first_name=validated_data.get('first_name', ''),
                last_name=validated_data.get('last_name', ''),
                account=None,  # Will be set after account creation
                role='owner'
            )

            # Create account with status and credits seeded (0 for paid pending)
            account = Account.objects.create(
                name=account_name,
                slug=self._unique_account_slug(account_name),
                owner=user,
                plan=plan,
                credits=initial_credits,
                status=account_status,
                payment_method=payment_method,
                # Save billing information
                billing_email=validated_data.get('billing_email', '') or validated_data.get('email', ''),
                billing_address_line1=validated_data.get('billing_address_line1', ''),
                billing_address_line2=validated_data.get('billing_address_line2', ''),
                billing_city=validated_data.get('billing_city', ''),
                billing_state=validated_data.get('billing_state', ''),
                billing_postal_code=validated_data.get('billing_postal_code', ''),
                billing_country=validated_data.get('billing_country', ''),
                tax_id=validated_data.get('tax_id', ''),
            )

            # Log initial credit transaction only for free/trial accounts with credits
            if initial_credits > 0:
                CreditTransaction.objects.create(
                    account=account,
                    transaction_type='subscription',
                    amount=initial_credits,
                    balance_after=initial_credits,
                    description=f'Free plan credits from {plan.name}',
                    metadata={
                        'plan_slug': plan.slug,
                        'registration': True,
                        'trial': True
                    }
                )

            # Update user to reference the new account
            user.account = account
            user.save()

            # For paid plans, create subscription, invoice, and default payment method
            if is_paid:
                self._create_paid_plan_records(
                    account,
                    plan,
                    payment_method,
                    billing_period_start,
                    billing_period_end,
                )

        return user
|
|
|
|
|
|
class LoginSerializer(serializers.Serializer):
    """Serializer for user login.

    Validates the shape of the credentials only (an email address and a
    password string); it performs no authentication itself.
    """

    email = serializers.EmailField()
    # write_only: accepted as input, never included in serialized output.
    password = serializers.CharField(write_only=True)
|
|
|
|
|
|
class ChangePasswordSerializer(serializers.Serializer):
    """Serializer for password change.

    ``new_password`` is run through Django's ``validate_password``; the
    confirmation must match it exactly.
    """

    old_password = serializers.CharField(write_only=True)
    new_password = serializers.CharField(
        write_only=True,
        validators=[validate_password],
    )
    new_password_confirm = serializers.CharField(write_only=True)

    def validate(self, attrs):
        """Reject the payload when the two new-password entries differ."""
        passwords_match = attrs['new_password'] == attrs['new_password_confirm']
        if not passwords_match:
            raise serializers.ValidationError({"new_password": "Passwords do not match"})
        return attrs
|
|
|
|
|
|
class RefreshTokenSerializer(serializers.Serializer):
    """Serializer for token refresh.

    Accepts a single required string field, ``refresh``.
    """

    refresh = serializers.CharField(
        required=True,
    )
|
|
|
|
|
|
class RequestPasswordResetSerializer(serializers.Serializer):
    """Serializer for password reset request.

    Accepts a single required field, ``email``, validated as an email address.
    """

    email = serializers.EmailField(
        required=True,
    )
|
|
|
|
|
|
class ResetPasswordSerializer(serializers.Serializer):
    """Serializer for password reset.

    Takes the reset ``token`` plus the new password entered twice;
    ``new_password`` is run through Django's ``validate_password``.
    """

    token = serializers.CharField(
        required=True,
    )
    new_password = serializers.CharField(
        write_only=True,
        validators=[validate_password],
    )
    new_password_confirm = serializers.CharField(write_only=True)

    def validate(self, attrs):
        """Reject the payload when the two new-password entries differ."""
        passwords_match = attrs['new_password'] == attrs['new_password_confirm']
        if not passwords_match:
            raise serializers.ValidationError({"new_password": "Passwords do not match"})
        return attrs
|
|
|
|
|
|
class SeedKeywordSerializer(serializers.ModelSerializer):
    """Serializer for SeedKeyword model.

    Adds read-only name/slug fields for the related industry and sector, and
    ``intent_display`` sourced from the model's ``get_intent_display``.
    """

    # Related industry, flattened.
    industry_name = serializers.CharField(
        source='industry.name',
        read_only=True,
    )
    industry_slug = serializers.CharField(
        source='industry.slug',
        read_only=True,
    )
    # Related sector, flattened.
    sector_name = serializers.CharField(
        source='sector.name',
        read_only=True,
    )
    sector_slug = serializers.CharField(
        source='sector.slug',
        read_only=True,
    )
    intent_display = serializers.CharField(
        source='get_intent_display',
        read_only=True,
    )

    class Meta:
        model = SeedKeyword
        fields = [
            'id',
            'keyword',
            'industry',
            'industry_name',
            'industry_slug',
            'sector',
            'sector_name',
            'sector_slug',
            'volume',
            'difficulty',
            'intent',
            'intent_display',
            'is_active',
            'created_at',
            'updated_at',
        ]
        read_only_fields = ['created_at', 'updated_at']
|