Backeup configs & cleanup of files and db
This commit is contained in:
@@ -1,2 +0,0 @@
|
||||
# Billing tests
|
||||
|
||||
@@ -1,299 +0,0 @@
|
||||
"""
|
||||
Concurrency tests for payment approval
|
||||
Tests race conditions and concurrent approval attempts
|
||||
"""
|
||||
import pytest
|
||||
from django.test import TestCase, TransactionTestCase
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.db import transaction
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from decimal import Decimal
|
||||
from igny8_core.business.billing.models import (
|
||||
Invoice, Payment, Subscription, Plan, Account
|
||||
)
|
||||
from igny8_core.business.billing.views import approve_payment
|
||||
from unittest.mock import Mock
|
||||
import threading
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
class PaymentApprovalConcurrencyTest(TransactionTestCase):
|
||||
"""Test concurrent payment approval scenarios"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
# Create admin user
|
||||
self.admin = User.objects.create_user(
|
||||
email='admin@test.com',
|
||||
password='testpass123',
|
||||
is_staff=True
|
||||
)
|
||||
|
||||
# Create account
|
||||
self.account = Account.objects.create(
|
||||
name='Test Account',
|
||||
owner=self.admin,
|
||||
credit_balance=0
|
||||
)
|
||||
|
||||
# Create plan
|
||||
self.plan = Plan.objects.create(
|
||||
name='Test Plan',
|
||||
slug='test-plan',
|
||||
price=Decimal('100.00'),
|
||||
currency='USD',
|
||||
billing_period='monthly',
|
||||
included_credits=1000
|
||||
)
|
||||
|
||||
# Create subscription
|
||||
self.subscription = Subscription.objects.create(
|
||||
account=self.account,
|
||||
plan=self.plan,
|
||||
status='pending_payment'
|
||||
)
|
||||
|
||||
# Create invoice
|
||||
self.invoice = Invoice.objects.create(
|
||||
account=self.account,
|
||||
invoice_number='INV-TEST-001',
|
||||
status='pending',
|
||||
subtotal=Decimal('100.00'),
|
||||
total_amount=Decimal('100.00'),
|
||||
currency='USD',
|
||||
invoice_type='subscription'
|
||||
)
|
||||
|
||||
# Create payment
|
||||
self.payment = Payment.objects.create(
|
||||
account=self.account,
|
||||
invoice=self.invoice,
|
||||
amount=Decimal('100.00'),
|
||||
currency='USD',
|
||||
payment_method='bank_transfer',
|
||||
status='pending_approval',
|
||||
manual_reference='TEST-REF-001'
|
||||
)
|
||||
|
||||
def test_concurrent_approval_attempts(self):
|
||||
"""
|
||||
Test that only one concurrent approval succeeds
|
||||
Multiple admins trying to approve same payment simultaneously
|
||||
"""
|
||||
num_threads = 5
|
||||
success_count = 0
|
||||
failure_count = 0
|
||||
results = []
|
||||
|
||||
def approve_payment_thread(payment_id, admin_user):
|
||||
"""Thread worker to approve payment"""
|
||||
try:
|
||||
# Simulate approval logic with transaction
|
||||
with transaction.atomic():
|
||||
payment = Payment.objects.select_for_update().get(id=payment_id)
|
||||
|
||||
# Check if already approved
|
||||
if payment.status == 'succeeded':
|
||||
return {'success': False, 'reason': 'already_approved'}
|
||||
|
||||
# Approve payment
|
||||
payment.status = 'succeeded'
|
||||
payment.approved_by = admin_user
|
||||
payment.save()
|
||||
|
||||
# Update invoice
|
||||
invoice = payment.invoice
|
||||
invoice.status = 'paid'
|
||||
invoice.save()
|
||||
|
||||
return {'success': True}
|
||||
|
||||
except Exception as e:
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
# Create multiple threads attempting approval
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = []
|
||||
for i in range(num_threads):
|
||||
future = executor.submit(approve_payment_thread, self.payment.id, self.admin)
|
||||
futures.append(future)
|
||||
|
||||
# Collect results
|
||||
for future in as_completed(futures):
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
if result.get('success'):
|
||||
success_count += 1
|
||||
else:
|
||||
failure_count += 1
|
||||
|
||||
# Verify only one approval succeeded
|
||||
self.assertEqual(success_count, 1, "Only one approval should succeed")
|
||||
self.assertEqual(failure_count, num_threads - 1, "Other attempts should fail")
|
||||
|
||||
# Verify final state
|
||||
payment = Payment.objects.get(id=self.payment.id)
|
||||
self.assertEqual(payment.status, 'succeeded')
|
||||
|
||||
invoice = Invoice.objects.get(id=self.invoice.id)
|
||||
self.assertEqual(invoice.status, 'paid')
|
||||
|
||||
def test_payment_and_invoice_consistency(self):
|
||||
"""
|
||||
Test that payment and invoice remain consistent under concurrent operations
|
||||
"""
|
||||
def read_payment_invoice(payment_id):
|
||||
"""Read payment and invoice status"""
|
||||
payment = Payment.objects.get(id=payment_id)
|
||||
invoice = Invoice.objects.get(id=payment.invoice_id)
|
||||
return {
|
||||
'payment_status': payment.status,
|
||||
'invoice_status': invoice.status,
|
||||
'consistent': (
|
||||
(payment.status == 'succeeded' and invoice.status == 'paid') or
|
||||
(payment.status == 'pending_approval' and invoice.status == 'pending')
|
||||
)
|
||||
}
|
||||
|
||||
# Approve payment in one thread
|
||||
def approve():
|
||||
with transaction.atomic():
|
||||
payment = Payment.objects.select_for_update().get(id=self.payment.id)
|
||||
payment.status = 'succeeded'
|
||||
payment.save()
|
||||
|
||||
invoice = Invoice.objects.select_for_update().get(id=self.invoice.id)
|
||||
invoice.status = 'paid'
|
||||
invoice.save()
|
||||
|
||||
# Read state in parallel threads
|
||||
results = []
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
# Start approval
|
||||
approval_future = executor.submit(approve)
|
||||
|
||||
# Multiple concurrent reads
|
||||
read_futures = [
|
||||
executor.submit(read_payment_invoice, self.payment.id)
|
||||
for _ in range(20)
|
||||
]
|
||||
|
||||
# Wait for approval
|
||||
approval_future.result()
|
||||
|
||||
# Collect read results
|
||||
for future in as_completed(read_futures):
|
||||
results.append(future.result())
|
||||
|
||||
# All reads should show consistent state
|
||||
for result in results:
|
||||
self.assertTrue(
|
||||
result['consistent'],
|
||||
f"Inconsistent state: payment={result['payment_status']}, invoice={result['invoice_status']}"
|
||||
)
|
||||
|
||||
def test_double_approval_prevention(self):
|
||||
"""
|
||||
Test that payment cannot be approved twice
|
||||
"""
|
||||
# First approval
|
||||
with transaction.atomic():
|
||||
payment = Payment.objects.select_for_update().get(id=self.payment.id)
|
||||
payment.status = 'succeeded'
|
||||
payment.approved_by = self.admin
|
||||
payment.save()
|
||||
|
||||
invoice = payment.invoice
|
||||
invoice.status = 'paid'
|
||||
invoice.save()
|
||||
|
||||
# Attempt second approval
|
||||
result = None
|
||||
try:
|
||||
with transaction.atomic():
|
||||
payment = Payment.objects.select_for_update().get(id=self.payment.id)
|
||||
|
||||
# Should detect already approved
|
||||
if payment.status == 'succeeded':
|
||||
result = 'already_approved'
|
||||
else:
|
||||
payment.status = 'succeeded'
|
||||
payment.save()
|
||||
result = 'approved'
|
||||
except Exception as e:
|
||||
result = f'error: {str(e)}'
|
||||
|
||||
self.assertEqual(result, 'already_approved', "Second approval should be prevented")
|
||||
|
||||
|
||||
class CreditTransactionConcurrencyTest(TransactionTestCase):
|
||||
"""Test concurrent credit additions/deductions"""
|
||||
|
||||
def setUp(self):
|
||||
self.admin = User.objects.create_user(
|
||||
email='admin@test.com',
|
||||
password='testpass123'
|
||||
)
|
||||
self.account = Account.objects.create(
|
||||
name='Test Account',
|
||||
owner=self.admin,
|
||||
credit_balance=1000
|
||||
)
|
||||
|
||||
def test_concurrent_credit_deductions(self):
|
||||
"""
|
||||
Test that concurrent credit deductions maintain correct balance
|
||||
"""
|
||||
initial_balance = self.account.credit_balance
|
||||
deduction_amount = 10
|
||||
num_operations = 20
|
||||
|
||||
def deduct_credits(account_id, amount):
|
||||
"""Deduct credits atomically"""
|
||||
from igny8_core.business.billing.models import CreditTransaction
|
||||
|
||||
with transaction.atomic():
|
||||
account = Account.objects.select_for_update().get(id=account_id)
|
||||
|
||||
# Check sufficient balance
|
||||
if account.credit_balance < amount:
|
||||
return {'success': False, 'reason': 'insufficient_credits'}
|
||||
|
||||
# Deduct credits
|
||||
account.credit_balance -= amount
|
||||
new_balance = account.credit_balance
|
||||
account.save()
|
||||
|
||||
# Record transaction
|
||||
CreditTransaction.objects.create(
|
||||
account=account,
|
||||
transaction_type='deduction',
|
||||
amount=-amount,
|
||||
balance_after=new_balance,
|
||||
description='Test deduction'
|
||||
)
|
||||
|
||||
return {'success': True, 'new_balance': new_balance}
|
||||
|
||||
# Concurrent deductions
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [
|
||||
executor.submit(deduct_credits, self.account.id, deduction_amount)
|
||||
for _ in range(num_operations)
|
||||
]
|
||||
|
||||
results = [future.result() for future in as_completed(futures)]
|
||||
|
||||
# Verify all succeeded
|
||||
success_count = sum(1 for r in results if r.get('success'))
|
||||
self.assertEqual(success_count, num_operations, "All deductions should succeed")
|
||||
|
||||
# Verify final balance
|
||||
self.account.refresh_from_db()
|
||||
expected_balance = initial_balance - (deduction_amount * num_operations)
|
||||
self.assertEqual(
|
||||
self.account.credit_balance,
|
||||
expected_balance,
|
||||
f"Final balance should be {expected_balance}"
|
||||
)
|
||||
@@ -1,141 +0,0 @@
|
||||
"""
|
||||
Test payment method filtering by country
|
||||
"""
|
||||
from django.test import TestCase, Client
|
||||
from django.contrib.auth import get_user_model
|
||||
from igny8_core.business.billing.models import PaymentMethodConfig
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
class PaymentMethodFilteringTest(TestCase):
|
||||
"""Test payment method filtering by billing country"""
|
||||
|
||||
def setUp(self):
|
||||
"""Create test payment method configs"""
|
||||
# Global methods (available everywhere)
|
||||
PaymentMethodConfig.objects.create(
|
||||
country_code='*',
|
||||
payment_method='stripe',
|
||||
display_name='Credit/Debit Card',
|
||||
is_enabled=True,
|
||||
sort_order=1,
|
||||
)
|
||||
PaymentMethodConfig.objects.create(
|
||||
country_code='*',
|
||||
payment_method='paypal',
|
||||
display_name='PayPal',
|
||||
is_enabled=True,
|
||||
sort_order=2,
|
||||
)
|
||||
|
||||
# Country-specific methods
|
||||
PaymentMethodConfig.objects.create(
|
||||
country_code='GB',
|
||||
payment_method='bank_transfer',
|
||||
display_name='Bank Transfer (UK)',
|
||||
is_enabled=True,
|
||||
sort_order=3,
|
||||
)
|
||||
PaymentMethodConfig.objects.create(
|
||||
country_code='IN',
|
||||
payment_method='local_wallet',
|
||||
display_name='UPI/Wallets',
|
||||
is_enabled=True,
|
||||
sort_order=4,
|
||||
)
|
||||
PaymentMethodConfig.objects.create(
|
||||
country_code='PK',
|
||||
payment_method='bank_transfer',
|
||||
display_name='Bank Transfer (Pakistan)',
|
||||
is_enabled=True,
|
||||
sort_order=5,
|
||||
)
|
||||
|
||||
# Disabled method (should not appear)
|
||||
PaymentMethodConfig.objects.create(
|
||||
country_code='*',
|
||||
payment_method='manual',
|
||||
display_name='Manual',
|
||||
is_enabled=False,
|
||||
sort_order=99,
|
||||
)
|
||||
|
||||
self.client = Client()
|
||||
|
||||
def test_filter_payment_methods_by_us(self):
|
||||
"""Test filtering for US country - should get only global methods"""
|
||||
response = self.client.get('/api/v1/billing/admin/payment-methods/?country=US')
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
data = response.json()
|
||||
|
||||
self.assertTrue(data['success'])
|
||||
self.assertEqual(len(data['results']), 2) # Only stripe and paypal
|
||||
|
||||
methods = [m['type'] for m in data['results']]
|
||||
self.assertIn('stripe', methods)
|
||||
self.assertIn('paypal', methods)
|
||||
|
||||
def test_filter_payment_methods_by_gb(self):
|
||||
"""Test filtering for GB - should get global + GB-specific"""
|
||||
response = self.client.get('/api/v1/billing/admin/payment-methods/?country=GB')
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
data = response.json()
|
||||
|
||||
self.assertTrue(data['success'])
|
||||
self.assertEqual(len(data['results']), 3) # stripe, paypal, bank_transfer(GB)
|
||||
|
||||
methods = [m['type'] for m in data['results']]
|
||||
self.assertIn('stripe', methods)
|
||||
self.assertIn('paypal', methods)
|
||||
self.assertIn('bank_transfer', methods)
|
||||
|
||||
def test_filter_payment_methods_by_in(self):
|
||||
"""Test filtering for IN - should get global + IN-specific"""
|
||||
response = self.client.get('/api/v1/billing/admin/payment-methods/?country=IN')
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
data = response.json()
|
||||
|
||||
self.assertTrue(data['success'])
|
||||
self.assertEqual(len(data['results']), 3) # stripe, paypal, local_wallet(IN)
|
||||
|
||||
methods = [m['type'] for m in data['results']]
|
||||
self.assertIn('stripe', methods)
|
||||
self.assertIn('paypal', methods)
|
||||
self.assertIn('local_wallet', methods)
|
||||
|
||||
def test_disabled_methods_not_returned(self):
|
||||
"""Test that disabled payment methods are not included"""
|
||||
response = self.client.get('/api/v1/billing/admin/payment-methods/?country=*')
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
data = response.json()
|
||||
|
||||
methods = [m['type'] for m in data['results']]
|
||||
self.assertNotIn('manual', methods) # Disabled method should not appear
|
||||
|
||||
def test_sort_order_respected(self):
|
||||
\"\"\"Test that payment methods are returned in sort_order\"\"\"
|
||||
response = self.client.get('/api/v1/billing/admin/payment-methods/?country=GB')
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
data = response.json()
|
||||
|
||||
# Verify first method has lowest sort_order
|
||||
self.assertEqual(data['results'][0]['type'], 'stripe')
|
||||
self.assertEqual(data['results'][0]['sort_order'], 1)
|
||||
|
||||
def test_default_country_fallback(self):
|
||||
"""Test that missing country parameter defaults to global (*)\"\"\"\n response = self.client.get('/api/v1/billing/admin/payment-methods/')
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
data = response.json()
|
||||
|
||||
self.assertTrue(data['success'])
|
||||
# Should get at least global methods
|
||||
methods = [m['type'] for m in data['results']]
|
||||
self.assertIn('stripe', methods)
|
||||
self.assertIn('paypal', methods)
|
||||
@@ -1,192 +0,0 @@
|
||||
"""
|
||||
Integration tests for payment workflow
|
||||
"""
|
||||
from django.test import TestCase
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.utils import timezone
|
||||
from decimal import Decimal
|
||||
from datetime import timedelta
|
||||
|
||||
from igny8_core.auth.models import Account, Plan, Subscription
|
||||
from igny8_core.business.billing.models import Invoice, Payment
|
||||
from igny8_core.business.billing.services.invoice_service import InvoiceService
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
class PaymentWorkflowIntegrationTest(TestCase):
|
||||
"""Test complete payment workflow including invoice.subscription FK"""
|
||||
|
||||
def setUp(self):
|
||||
"""Create test data"""
|
||||
# Create plan
|
||||
self.plan = Plan.objects.create(
|
||||
name='Test Plan',
|
||||
slug='test-plan',
|
||||
price=Decimal('29.00'),
|
||||
included_credits=1000,
|
||||
max_sites=5,
|
||||
)
|
||||
|
||||
# Create account
|
||||
self.account = Account.objects.create(
|
||||
name='Test Account',
|
||||
slug='test-account',
|
||||
status='pending_payment',
|
||||
billing_country='US',
|
||||
billing_email='test@example.com',
|
||||
)
|
||||
|
||||
# Create user
|
||||
self.user = User.objects.create_user(
|
||||
username='testuser',
|
||||
email='testuser@example.com',
|
||||
password='testpass123',
|
||||
account=self.account,
|
||||
)
|
||||
|
||||
# Create subscription
|
||||
billing_period_start = timezone.now()
|
||||
billing_period_end = billing_period_start + timedelta(days=30)
|
||||
|
||||
self.subscription = Subscription.objects.create(
|
||||
account=self.account,
|
||||
plan=self.plan,
|
||||
status='pending_payment',
|
||||
current_period_start=billing_period_start,
|
||||
current_period_end=billing_period_end,
|
||||
)
|
||||
|
||||
def test_invoice_subscription_fk_relationship(self):
|
||||
"""Test that invoice.subscription FK works correctly"""
|
||||
# Create invoice via service
|
||||
billing_period_start = timezone.now()
|
||||
billing_period_end = billing_period_start + timedelta(days=30)
|
||||
|
||||
invoice = InvoiceService.create_subscription_invoice(
|
||||
subscription=self.subscription,
|
||||
billing_period_start=billing_period_start,
|
||||
billing_period_end=billing_period_end,
|
||||
)
|
||||
|
||||
# Verify FK relationship
|
||||
self.assertIsNotNone(invoice.subscription)
|
||||
self.assertEqual(invoice.subscription.id, self.subscription.id)
|
||||
self.assertEqual(invoice.subscription.plan.id, self.plan.id)
|
||||
|
||||
# Verify can access subscription from invoice
|
||||
self.assertEqual(invoice.subscription.account, self.account)
|
||||
self.assertEqual(invoice.subscription.plan.name, 'Test Plan')
|
||||
|
||||
def test_payment_approval_with_subscription(self):
|
||||
"""Test payment approval workflow uses invoice.subscription"""
|
||||
# Create invoice
|
||||
billing_period_start = timezone.now()
|
||||
billing_period_end = billing_period_start + timedelta(days=30)
|
||||
|
||||
invoice = InvoiceService.create_subscription_invoice(
|
||||
subscription=self.subscription,
|
||||
billing_period_start=billing_period_start,
|
||||
billing_period_end=billing_period_end,
|
||||
)
|
||||
|
||||
# Create payment
|
||||
payment = Payment.objects.create(
|
||||
account=self.account,
|
||||
invoice=invoice,
|
||||
amount=invoice.total,
|
||||
currency='USD',
|
||||
status='pending_approval',
|
||||
payment_method='bank_transfer',
|
||||
manual_reference='TEST-REF-001',
|
||||
)
|
||||
|
||||
# Verify payment links to invoice which links to subscription
|
||||
self.assertIsNotNone(payment.invoice)
|
||||
self.assertIsNotNone(payment.invoice.subscription)
|
||||
self.assertEqual(payment.invoice.subscription.id, self.subscription.id)
|
||||
|
||||
# Simulate approval workflow
|
||||
payment.status = 'succeeded'
|
||||
payment.approved_by = self.user
|
||||
payment.approved_at = timezone.now()
|
||||
payment.save()
|
||||
|
||||
# Update related records
|
||||
invoice.status = 'paid'
|
||||
invoice.paid_at = timezone.now()
|
||||
invoice.save()
|
||||
|
||||
subscription = invoice.subscription
|
||||
subscription.status = 'active'
|
||||
subscription.save()
|
||||
|
||||
# Verify workflow completed successfully
|
||||
self.assertEqual(payment.status, 'succeeded')
|
||||
self.assertEqual(invoice.status, 'paid')
|
||||
self.assertEqual(subscription.status, 'active')
|
||||
self.assertEqual(subscription.plan.included_credits, 1000)
|
||||
|
||||
def test_subscription_dates_not_null_for_paid_plans(self):
|
||||
"""Test that subscription dates are set for paid plans"""
|
||||
self.assertIsNotNone(self.subscription.current_period_start)
|
||||
self.assertIsNotNone(self.subscription.current_period_end)
|
||||
|
||||
# Verify dates are in future
|
||||
self.assertGreater(self.subscription.current_period_end, self.subscription.current_period_start)
|
||||
|
||||
def test_invoice_currency_based_on_country(self):
|
||||
"""Test that invoice currency is set based on billing country"""
|
||||
# Test US -> USD
|
||||
self.account.billing_country = 'US'
|
||||
self.account.save()
|
||||
|
||||
billing_period_start = timezone.now()
|
||||
billing_period_end = billing_period_start + timedelta(days=30)
|
||||
|
||||
invoice_us = InvoiceService.create_subscription_invoice(
|
||||
subscription=self.subscription,
|
||||
billing_period_start=billing_period_start,
|
||||
billing_period_end=billing_period_end,
|
||||
)
|
||||
self.assertEqual(invoice_us.currency, 'USD')
|
||||
|
||||
# Test GB -> GBP
|
||||
self.account.billing_country = 'GB'
|
||||
self.account.save()
|
||||
|
||||
invoice_gb = InvoiceService.create_subscription_invoice(
|
||||
subscription=self.subscription,
|
||||
billing_period_start=billing_period_start,
|
||||
billing_period_end=billing_period_end,
|
||||
)
|
||||
self.assertEqual(invoice_gb.currency, 'GBP')
|
||||
|
||||
# Test IN -> INR
|
||||
self.account.billing_country = 'IN'
|
||||
self.account.save()
|
||||
|
||||
invoice_in = InvoiceService.create_subscription_invoice(
|
||||
subscription=self.subscription,
|
||||
billing_period_start=billing_period_start,
|
||||
billing_period_end=billing_period_end,
|
||||
)
|
||||
self.assertEqual(invoice_in.currency, 'INR')
|
||||
|
||||
def test_invoice_due_date_grace_period(self):
|
||||
"""Test that invoice due date uses grace period instead of billing_period_end"""
|
||||
billing_period_start = timezone.now()
|
||||
billing_period_end = billing_period_start + timedelta(days=30)
|
||||
|
||||
invoice = InvoiceService.create_subscription_invoice(
|
||||
subscription=self.subscription,
|
||||
billing_period_start=billing_period_start,
|
||||
billing_period_end=billing_period_end,
|
||||
)
|
||||
|
||||
# Verify due date is invoice_date + 7 days (grace period)
|
||||
expected_due_date = invoice.invoice_date + timedelta(days=7)
|
||||
self.assertEqual(invoice.due_date, expected_due_date)
|
||||
|
||||
# Verify it's NOT billing_period_end
|
||||
self.assertNotEqual(invoice.due_date, billing_period_end.date())
|
||||
@@ -1,133 +0,0 @@
|
||||
"""
|
||||
Tests for Phase 4 credit deduction
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
from django.test import TestCase
|
||||
from igny8_core.business.content.models import Content
|
||||
from igny8_core.business.billing.services.credit_service import CreditService
|
||||
from igny8_core.business.billing.constants import CREDIT_COSTS
|
||||
from igny8_core.business.billing.exceptions import InsufficientCreditsError
|
||||
from igny8_core.api.tests.test_integration_base import IntegrationTestBase
|
||||
|
||||
|
||||
class Phase4CreditTests(IntegrationTestBase):
|
||||
"""Tests for Phase 4 credit deduction"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Set initial credits
|
||||
self.account.credits = 1000
|
||||
self.account.save()
|
||||
|
||||
def test_linking_deducts_correct_credits(self):
|
||||
"""Test that linking deducts correct credits"""
|
||||
cost = CreditService.get_credit_cost('linking')
|
||||
expected_cost = CREDIT_COSTS.get('linking', 0)
|
||||
|
||||
self.assertEqual(cost, expected_cost)
|
||||
self.assertEqual(cost, 8) # From constants
|
||||
|
||||
def test_optimization_deducts_correct_credits(self):
|
||||
"""Test that optimization deducts correct credits based on word count"""
|
||||
word_count = 500
|
||||
cost = CreditService.get_credit_cost('optimization', word_count)
|
||||
|
||||
# Should be 1 credit per 200 words, so 500 words = 3 credits (max(1, 1 * 500/200) = 3)
|
||||
expected = max(1, int(CREDIT_COSTS.get('optimization', 1) * (word_count / 200)))
|
||||
self.assertEqual(cost, expected)
|
||||
|
||||
def test_optimization_credits_per_entry_point(self):
|
||||
"""Test that optimization credits are same regardless of entry point"""
|
||||
word_count = 400
|
||||
|
||||
# All entry points should use same credit calculation
|
||||
cost = CreditService.get_credit_cost('optimization', word_count)
|
||||
|
||||
# 400 words = 2 credits (1 * 400/200)
|
||||
self.assertEqual(cost, 2)
|
||||
|
||||
@patch('igny8_core.business.billing.services.credit_service.CreditService.deduct_credits')
|
||||
def test_pipeline_deducts_credits_at_each_stage(self, mock_deduct):
|
||||
"""Test that pipeline deducts credits at each stage"""
|
||||
from igny8_core.business.content.services.content_pipeline_service import ContentPipelineService
|
||||
from igny8_core.business.linking.services.linker_service import LinkerService
|
||||
from igny8_core.business.optimization.services.optimizer_service import OptimizerService
|
||||
|
||||
content = Content.objects.create(
|
||||
account=self.account,
|
||||
site=self.site,
|
||||
sector=self.sector,
|
||||
title="Test",
|
||||
word_count=400,
|
||||
source='igny8'
|
||||
)
|
||||
|
||||
# Mock the services
|
||||
with patch.object(LinkerService, 'process') as mock_link, \
|
||||
patch.object(OptimizerService, 'optimize_from_writer') as mock_optimize:
|
||||
|
||||
mock_link.return_value = content
|
||||
mock_optimize.return_value = content
|
||||
|
||||
service = ContentPipelineService()
|
||||
service.process_writer_content(content.id)
|
||||
|
||||
# Should deduct credits for both linking and optimization
|
||||
self.assertGreater(mock_deduct.call_count, 0)
|
||||
|
||||
def test_insufficient_credits_blocks_linking(self):
|
||||
"""Test that insufficient credits blocks linking"""
|
||||
self.account.credits = 5 # Less than linking cost (8)
|
||||
self.account.save()
|
||||
|
||||
with self.assertRaises(InsufficientCreditsError):
|
||||
CreditService.check_credits(self.account, 'linking')
|
||||
|
||||
def test_insufficient_credits_blocks_optimization(self):
|
||||
"""Test that insufficient credits blocks optimization"""
|
||||
self.account.credits = 1 # Less than optimization cost for 500 words
|
||||
self.account.save()
|
||||
|
||||
with self.assertRaises(InsufficientCreditsError):
|
||||
CreditService.check_credits(self.account, 'optimization', 500)
|
||||
|
||||
def test_credit_deduction_logged(self):
|
||||
"""Test that credit deduction is logged"""
|
||||
from igny8_core.business.billing.models import CreditUsageLog
|
||||
|
||||
initial_credits = self.account.credits
|
||||
cost = CreditService.get_credit_cost('linking')
|
||||
|
||||
CreditService.deduct_credits_for_operation(
|
||||
account=self.account,
|
||||
operation_type='linking',
|
||||
description="Test linking"
|
||||
)
|
||||
|
||||
self.account.refresh_from_db()
|
||||
self.assertEqual(self.account.credits, initial_credits - cost)
|
||||
|
||||
# Check that usage log was created
|
||||
log = CreditUsageLog.objects.filter(
|
||||
account=self.account,
|
||||
operation_type='linking'
|
||||
).first()
|
||||
self.assertIsNotNone(log)
|
||||
|
||||
def test_batch_operations_deduct_multiple_credits(self):
|
||||
"""Test that batch operations deduct multiple credits"""
|
||||
initial_credits = self.account.credits
|
||||
linking_cost = CreditService.get_credit_cost('linking')
|
||||
|
||||
# Deduct for 3 linking operations
|
||||
for i in range(3):
|
||||
CreditService.deduct_credits_for_operation(
|
||||
account=self.account,
|
||||
operation_type='linking',
|
||||
description=f"Linking {i}"
|
||||
)
|
||||
|
||||
self.account.refresh_from_db()
|
||||
expected_credits = initial_credits - (linking_cost * 3)
|
||||
self.assertEqual(self.account.credits, expected_credits)
|
||||
|
||||
Reference in New Issue
Block a user