350 lines
13 KiB
Python
350 lines
13 KiB
Python
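"""
Unified Celery task entrypoints for AI functions.

- run_ai_task: single entrypoint that looks up a function in the AI function
  registry and executes it through AIEngine.
- process_image_generation_queue: generates queued images one at a time,
  reporting per-image progress through Celery task meta.
"""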
|
|
import logging
|
|
from celery import shared_task
|
|
from igny8_core.ai.engine import AIEngine
|
|
from igny8_core.ai.registry import get_function_instance
|
|
|
|
# Module-level logger shared by the Celery tasks defined in this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _report_ai_task_failure(task, error_meta):
    """
    Record a failed AI task on the Celery backend and build its error result.

    Best-effort: if ``task.update_state`` raises, the problem is logged as a
    warning and the error dict is still returned, so callers never see an
    exception from this helper.

    Args:
        task: The bound Celery task (``self`` inside ``run_ai_task``).
        error_meta: Error details; expected to contain ``error``,
            ``error_type`` and ``function_name`` plus any extra fields.

    Returns:
        ``{'success': False, **error_meta}``.
    """
    # NOTE(review): Celery stores the task's return value with state SUCCESS
    # once the task returns normally, so this FAILURE state is most likely
    # overwritten in the result backend. Clients should rely on the returned
    # 'success' flag. Confirm against the result backend in use.
    try:
        task.update_state(state='FAILURE', meta=error_meta)
    except Exception as update_err:
        logger.warning(f"Failed to update task state: {update_err}")
    # Not raised: the task deliberately returns errors as plain dicts.
    return {'success': False, **error_meta}


# NOTE(review): max_retries is configured, but the body never calls
# self.retry(), so no retries actually happen.
@shared_task(bind=True, max_retries=3)
def run_ai_task(self, function_name: str, payload: dict, account_id: int = None):
    """
    Single Celery entrypoint for all AI functions.

    Looks up ``function_name`` in the AI function registry and executes it
    with ``AIEngine``. Errors are never raised out of the task. Instead:

    - An unknown function returns ``{'success': False, 'error': ...}``.
    - A failed execution, or an exception, returns ``{'success': False,
      'error', 'error_type', 'function_name', 'phase', 'percentage': 0,
      'message'}``. The engine-failure path also includes ``request_steps``
      and ``response_steps``.

    Args:
        function_name: Name of the AI function (e.g., 'auto_cluster')
        payload: Function-specific payload
        account_id: Account ID for account isolation. If it is given but the
            account does not exist, a warning is logged and the function runs
            with ``account=None``.

    Returns:
        The engine's result dict on success, otherwise an error dict as
        described above.
    """
    logger.info("=" * 80)
    logger.info(f"run_ai_task STARTED: {function_name}")
    logger.info(f" - Task ID: {self.request.id}")
    logger.info(f" - Function: {function_name}")
    logger.info(f" - Account ID: {account_id}")
    # Guarded so a missing payload cannot raise here, outside the try below.
    logger.info(f" - Payload keys: {list((payload or {}).keys())}")
    logger.info("=" * 80)

    try:
        # Resolve the account (optional). A missing account is not fatal.
        account = None
        if account_id:
            from igny8_core.auth.models import Account
            try:
                account = Account.objects.get(id=account_id)
            except Account.DoesNotExist:
                logger.warning(f"Account {account_id} not found")

        # Get function from registry
        fn = get_function_instance(function_name)
        if not fn:
            error_msg = f'Function {function_name} not found in registry'
            logger.error(error_msg)
            return {
                'success': False,
                'error': error_msg
            }

        # Create engine and execute
        engine = AIEngine(celery_task=self, account=account)
        result = engine.execute(fn, payload)

        logger.info("=" * 80)
        logger.info(f"run_ai_task COMPLETED: {function_name}")
        logger.info(f" - Success: {result.get('success')}")
        if not result.get('success'):
            logger.error(f" - Error: {result.get('error')}")
        logger.info("=" * 80)

        if result.get('success'):
            return result

        # Engine reported a failure: surface its details, including any steps.
        error_msg = result.get('error', 'Task execution failed')
        return _report_ai_task_failure(self, {
            'error': error_msg,
            'error_type': result.get('error_type', 'ExecutionError'),
            'function_name': function_name,
            'phase': result.get('phase', 'ERROR'),
            'percentage': 0,
            'message': f'Error: {error_msg}',
            'request_steps': result.get('request_steps', []),
            'response_steps': result.get('response_steps', []),
        })

    except Exception as e:
        error_type = type(e).__name__
        error_msg = str(e)

        logger.error("=" * 80)
        logger.error(f"run_ai_task FAILED: {function_name}")
        logger.error(f" - Error: {error_type}: {error_msg}")
        logger.error("=" * 80, exc_info=True)

        return _report_ai_task_failure(self, {
            'error': error_msg,
            'error_type': error_type,
            'function_name': function_name,
            'phase': 'ERROR',
            'percentage': 0,
            'message': f'Error: {error_msg}',
        })
|
|
|
|
|
|
def _failed_image_result(image_id, error):
    """Build the per-image ``results`` entry for an image that failed."""
    return {
        'image_id': image_id,
        'status': 'failed',
        'error': error
    }


@shared_task(bind=True, name='igny8_core.ai.tasks.process_image_generation_queue')
def process_image_generation_queue(self, image_ids: list, account_id: int = None, content_id: int = None):
    """
    Process image generation queue sequentially (one image at a time).

    Before each image, the Celery task meta is updated (state 'PROGRESS')
    with the running counters and the per-image results so far.

    Args:
        image_ids: IDs of ``Images`` rows to generate, processed in order.
            ``None`` is treated as an empty list.
        account_id: Account owning the images and integration settings. If it
            is given but not found, the task returns an error immediately.
            If it is not given, lookups are made with ``account=None``.
        content_id: Only logged; not otherwise used by this task.

    Returns:
        ``{'success': False, 'error': ...}`` when the account, image settings,
        provider integration or API key cannot be resolved. Otherwise
        ``{'success': True, 'total_images', 'completed', 'failed', 'results'}``.
        Each ``results`` entry is either ``{'image_id', 'status': 'completed',
        'image_url', 'revised_prompt'}`` or
        ``{'image_id', 'status': 'failed', 'error'}``.
    """
    from igny8_core.modules.writer.models import Images
    from igny8_core.modules.system.models import IntegrationSettings
    from igny8_core.ai.ai_core import AICore
    from igny8_core.utils.prompt_registry import PromptRegistry

    logger.info("=" * 80)
    logger.info(f"process_image_generation_queue STARTED")
    logger.info(f" - Task ID: {self.request.id}")
    logger.info(f" - Image IDs: {image_ids}")
    logger.info(f" - Account ID: {account_id}")
    logger.info(f" - Content ID: {content_id}")
    logger.info("=" * 80)

    # Normalize so a missing list doesn't crash len() below.
    image_ids = list(image_ids or [])

    account = None
    if account_id:
        from igny8_core.auth.models import Account
        try:
            account = Account.objects.get(id=account_id)
        except Account.DoesNotExist:
            logger.error(f"Account {account_id} not found")
            return {'success': False, 'error': 'Account not found'}

    # Progress tracking
    total_images = len(image_ids)
    completed = 0
    failed = 0
    results = []

    # Image generation settings. The config may also carry image_format,
    # desktop_enabled and mobile_enabled, which this task does not use.
    try:
        image_settings = IntegrationSettings.objects.get(
            account=account,
            integration_type='image_generation',
            is_active=True
        )
    except IntegrationSettings.DoesNotExist:
        logger.error("Image generation settings not found")
        return {'success': False, 'error': 'Image generation settings not found'}

    config = image_settings.config or {}
    provider = config.get('provider', 'openai')
    model = config.get('model', 'dall-e-3')
    image_type = config.get('image_type', 'realistic')

    # Provider API key: each supported provider stores it in an
    # IntegrationSettings row whose integration_type equals the provider name.
    if provider not in ('openai', 'runware'):
        return {'success': False, 'error': f'Unknown provider: {provider}'}
    try:
        provider_settings = IntegrationSettings.objects.get(
            account=account,
            integration_type=provider,
            is_active=True
        )
    except IntegrationSettings.DoesNotExist:
        return {'success': False, 'error': f'{provider} integration not found'}
    api_key = (provider_settings.config or {}).get('api_key')
    if not api_key:
        return {'success': False, 'error': f'{provider} API key not configured'}

    # Prompt templates (negative prompts are only sent to runware).
    try:
        image_prompt_template = PromptRegistry.get_image_prompt_template(account)
        negative_prompt = PromptRegistry.get_negative_prompt(account) if provider == 'runware' else None
    except Exception as e:
        logger.warning(f"Failed to get prompt templates: {e}, using fallback")
        image_prompt_template = "{image_prompt}"
        negative_prompt = None

    ai_core = AICore(account=account)

    def process_image(image_id, index):
        """Generate one queued image and return its ``results`` entry."""
        try:
            image = Images.objects.get(id=image_id, account=account)
        except Images.DoesNotExist:
            logger.error(f"Image {image_id} not found")
            return _failed_image_result(image_id, 'Image record not found')

        if not image.prompt:
            logger.warning(f"Image {image_id} has no prompt")
            return _failed_image_result(image_id, 'No prompt found')

        content = image.content
        if not content:
            logger.warning(f"Image {image_id} has no content")
            return _failed_image_result(image_id, 'No content associated')

        # Templates may reference unknown placeholders; fall back if so.
        try:
            formatted_prompt = image_prompt_template.format(
                post_title=content.title or content.meta_title or f"Content #{content.id}",
                image_prompt=image.prompt,
                image_type=image_type
            )
        except Exception as e:
            logger.warning(f"Prompt template formatting failed: {e}, using fallback")
            formatted_prompt = f"{image.prompt}, {image_type} style"

        logger.info(f"Generating image {index}/{total_images} (ID: {image_id})")
        result = ai_core.generate_image(
            prompt=formatted_prompt,
            provider=provider,
            model=model,
            size='1024x1024',
            api_key=api_key,
            negative_prompt=negative_prompt,
            function_name='generate_images_from_prompts'
        )

        error = result.get('error')
        image_url = result.get('url')
        # Without a URL there is nothing to store; don't mark it 'generated'.
        if not error and not image_url:
            error = 'Image generation returned no URL'

        if error:
            logger.error(f"Image generation failed for {image_id}: {error}")
            image.status = 'failed'
            image.save(update_fields=['status'])
            return _failed_image_result(image_id, error)

        logger.info(f"Image generation successful for {image_id}")
        image.image_url = image_url
        image.status = 'generated'
        image.save(update_fields=['image_url', 'status'])
        return {
            'image_id': image_id,
            'status': 'completed',
            'image_url': image_url,
            'revised_prompt': result.get('revised_prompt')
        }

    for index, image_id in enumerate(image_ids, 1):
        # Progress reporting is best-effort: a result-backend error should not
        # cause this image to be skipped and counted as failed.
        try:
            self.update_state(
                state='PROGRESS',
                meta={
                    'current_image': index,
                    'total_images': total_images,
                    'completed': completed,
                    'failed': failed,
                    'status': 'processing',
                    'current_image_id': image_id,
                    'results': results
                }
            )
        except Exception as update_err:
            logger.warning(f"Failed to update task state: {update_err}")

        try:
            entry = process_image(image_id, index)
        except Exception as e:
            logger.error(f"Error processing image {image_id}: {str(e)}", exc_info=True)
            entry = _failed_image_result(image_id, str(e))

        results.append(entry)
        if entry['status'] == 'completed':
            completed += 1
        else:
            failed += 1

    logger.info("=" * 80)
    logger.info(f"process_image_generation_queue COMPLETED")
    logger.info(f" - Total: {total_images}")
    logger.info(f" - Completed: {completed}")
    logger.info(f" - Failed: {failed}")
    logger.info("=" * 80)

    return {
        'success': True,
        'total_images': total_images,
        'completed': completed,
        'failed': failed,
        'results': results
    }
|