"""
AI Engine - Central orchestrator for all AI functions.

Exposes ``AIEngine``, which runs ``BaseAIFunction`` implementations through a
phased pipeline with progress tracking, cost tracking and task logging.
"""
import logging
from typing import Dict, Any, Optional

from igny8_core.ai.base import BaseAIFunction
from igny8_core.ai.tracker import StepTracker, ProgressTracker, CostTracker, ConsoleStepTracker
from igny8_core.ai.ai_core import AICore
from igny8_core.ai.settings import get_model_config

logger = logging.getLogger(__name__)


class AIEngine:
    """
    Central orchestrator for all AI functions.

    ``execute()`` drives a ``BaseAIFunction`` through the INIT, PREP, AI_CALL,
    PARSE, SAVE and DONE phases. Along the way it:

    - publishes progress through ``ProgressTracker`` / ``StepTracker``;
    - logs each phase to a per-run ``ConsoleStepTracker``;
    - accumulates cost and tokens in ``CostTracker``;
    - writes an ``AITaskLog`` row for every success or failure.
    """

    # Noun for each function's input items (used by _get_input_description).
    _INPUT_NOUNS = {
        'auto_cluster': 'keyword',
        'generate_ideas': 'cluster',
        'generate_content': 'task',
        'generate_images': 'task',
    }

    # Noun for each function's output items (used by parse/save messages).
    _OUTPUT_NOUNS = {
        'auto_cluster': 'cluster',
        'generate_ideas': 'idea',
        'generate_content': 'article',
        'generate_images': 'image',
        'generate_images_from_prompts': 'image',
    }

    # Count-free parse messages (used by _get_parse_message).
    _PARSE_MESSAGES = {
        'auto_cluster': "Organizing clusters",
        'generate_ideas': "Structuring outlines",
        'generate_content': "Formatting content",
        'generate_images': "Processing images",
    }

    def __init__(self, celery_task=None, account=None):
        """
        Args:
            celery_task: Optional Celery task.
                - It is handed to ``ProgressTracker``.
                - Its ``request.id`` is stored as ``AITaskLog.task_id``.
            account: Account the work runs for.
                - Used for validation, settings lookup, and log/credit records.
                - Database logging is skipped when it is falsy.
        """
        self.task = celery_task
        self.account = account
        self.tracker = ProgressTracker(celery_task)
        self.step_tracker = StepTracker('ai_engine')  # For Celery progress callbacks
        self.console_tracker = None  # Created fresh by each execute() call
        self.cost_tracker = CostTracker()

    @staticmethod
    def _plural(count, noun: str) -> str:
        """Return ``"<count> <noun>"``, appending "s" to noun unless count == 1."""
        return f"{count} {noun}{'' if count == 1 else 's'}"

    def _get_input_description(self, function_name: str, payload: dict, count: int) -> str:
        """Describe the input items, e.g. "3 keywords".

        ``payload`` is accepted for interface stability but not used.
        """
        return self._plural(count, self._INPUT_NOUNS.get(function_name, 'item'))

    def _build_validation_message(self, function_name: str, payload: dict, count: int, input_description: str) -> str:
        """Build the INIT-phase message.

        For ``auto_cluster``, the message names up to three of the account's
        keywords, e.g. "Validating a, b, c and 4 more keywords". Otherwise, or if
        the keyword lookup fails or finds nothing, it is
        "Validating <input_description>".
        """
        if function_name == 'auto_cluster' and count > 0:
            try:
                from igny8_core.modules.planner.models import Keywords

                ids = payload.get('ids', [])
                keyword_list = list(
                    Keywords.objects.filter(id__in=ids, account=self.account)
                    .values_list('keyword', flat=True)[:3]
                )
                if keyword_list:
                    keywords_text = ', '.join(keyword_list)
                    remaining = count - len(keyword_list)
                    if remaining > 0:
                        return f"Validating {keywords_text} and {self._plural(remaining, 'more keyword')}"
                    return f"Validating {keywords_text}"
            except Exception as e:
                # Names are cosmetic only; fall back to the count message.
                logger.warning(f"Failed to load keyword names for validation message: {e}")

        # Fallback to simple count message
        return f"Validating {input_description}"

    def _get_prep_message(self, function_name: str, count: int, data: Any) -> str:
        """Build the PREP-phase message from the prepared item count and data."""
        if function_name == 'auto_cluster':
            return f"Loading {self._plural(count, 'keyword')}"
        if function_name == 'generate_ideas':
            return f"Loading {self._plural(count, 'cluster')}"
        if function_name == 'generate_content':
            return f"Preparing {self._plural(count, 'content idea')}"
        if function_name == 'generate_images':
            return f"Extracting image prompts from {self._plural(count, 'task')}"
        if function_name == 'generate_image_prompts':
            # Total prompts = 1 featured + max_images in-article (max_images defaults to 2).
            image_settings = None
            if isinstance(data, list) and data and isinstance(data[0], dict):
                image_settings = data[0]
            elif isinstance(data, dict) and 'max_images' in data:
                image_settings = data
            if image_settings is not None:
                total_images = 1 + image_settings.get('max_images', 2)
                return f"Mapping Content for {total_images} Image Prompts"
            return "Mapping Content for Image Prompts"
        if function_name == 'generate_images_from_prompts':
            if isinstance(data, dict) and 'images' in data:
                total_images = len(data['images'])
                return f"Preparing to generate {self._plural(total_images, 'image')}"
            return "Preparing image generation queue"
        return f"Preparing {self._plural(count, 'item')}"

    def _get_ai_call_message(self, function_name: str, count: int) -> str:
        """Build the message shown when the AI_CALL phase starts."""
        suffix = '' if count == 1 else 's'
        if function_name == 'auto_cluster':
            return f"Grouping {self._plural(count, 'keyword')} into clusters"
        if function_name == 'generate_ideas':
            return f"Generating content ideas for {self._plural(count, 'cluster')}"
        if function_name == 'generate_content':
            return f"Writing article{suffix} with AI"
        if function_name == 'generate_images':
            return f"Creating image{suffix} with AI"
        if function_name == 'generate_images_from_prompts':
            return "Generating images with AI"
        return "Processing with AI"

    def _get_parse_message(self, function_name: str) -> str:
        """Build the count-free message shown when PARSE starts."""
        return self._PARSE_MESSAGES.get(function_name, "Processing results")

    def _get_parse_message_with_count(self, function_name: str, count: int) -> str:
        """Build the PARSE-phase completion message including the parsed count."""
        if function_name == 'generate_image_prompts':
            # count includes the featured-image prompt; report in-article prompts only.
            in_article_count = max(0, count - 1)
            if in_article_count > 0:
                return f"Writing {in_article_count} In‑article Image Prompts"
            return "Writing In‑article Image Prompts"
        if function_name == 'generate_images_from_prompts':
            return f"{self._plural(count, 'image')} generated"
        noun = self._OUTPUT_NOUNS.get(function_name)
        if noun is None:
            return f"{self._plural(count, 'item')} processed"
        return f"{self._plural(count, noun)} created"

    def _get_save_message(self, function_name: str, count: int) -> str:
        """Build the SAVE-phase message for ``count`` saved items."""
        if function_name == 'generate_image_prompts':
            # count is the total number of prompts created.
            return f"Assigning {count} Prompts to Dedicated Slots"
        return f"Saving {self._plural(count, self._OUTPUT_NOUNS.get(function_name, 'item'))}"

    @staticmethod
    def _count_prepared_items(data: Any, input_count: int):
        """Return how many items ``fn.prepare`` produced.

        Falls back to ``input_count`` when the data shape does not say.
        """
        if isinstance(data, (list, tuple)):
            return len(data)
        if isinstance(data, dict):
            # cluster_data is produced for generate_ideas, keywords for auto_cluster.
            if 'cluster_data' in data:
                return len(data['cluster_data'])
            if 'keywords' in data:
                return len(data['keywords'])
            return data.get('count', input_count)
        return input_count

    @staticmethod
    def _count_parsed_items(parsed: Any):
        """Return how many items ``fn.parse_response`` produced."""
        if isinstance(parsed, (list, tuple)):
            return len(parsed)
        if isinstance(parsed, dict):
            # A content dict (has 'content') is a single item; otherwise trust 'count'.
            return 1 if 'content' in parsed else parsed.get('count', 1)
        return 1

    def _get_integration_model(self) -> Any:
        """Return the 'model' configured in the account's OpenAI integration, if any.

        Reads the active OpenAI ``IntegrationSettings``. The value is only
        displayed in console logs. Returns None when there is no account, no
        setting, or the lookup fails.
        """
        if not self.account:
            return None
        try:
            from igny8_core.modules.system.models import IntegrationSettings

            openai_settings = IntegrationSettings.objects.filter(
                integration_type='openai',
                account=self.account,
                is_active=True,
            ).first()
            if openai_settings and openai_settings.config:
                return openai_settings.config.get('model')
        except Exception as integration_error:
            logger.warning(
                "[AIEngine] Unable to read model from IntegrationSettings: %s",
                integration_error,
                exc_info=True,
            )
        return None

    def _log_credit_usage(self, function_name: str, data: Any, raw_response: dict,
                          clusters_created: Any, keywords_updated: Any) -> None:
        """Record a CreditUsageLog row for this run.

        This is best-effort: any failure is logged as a warning and swallowed.
        Credits are only logged here; account credits are not deducted.
        """
        try:
            from igny8_core.modules.billing.models import CreditUsageLog

            # Size of the prepared input: keywords for dict data, rows for list data.
            if isinstance(data, dict):
                item_count = len(data.get('keywords', []))
            elif isinstance(data, list):
                item_count = len(data)
            else:
                item_count = 1

            credits_used = self._calculate_credits_for_clustering(
                keyword_count=item_count,
                tokens=raw_response.get('total_tokens', 0),
                cost=raw_response.get('cost', 0),
            )

            # NOTE(review): operation_type / related_object_type are hard-coded to
            # clustering values for every function, so ideas, content and image
            # runs are also logged as 'clustering'. Confirm the CreditUsageLog
            # choices before mapping function_name to per-function values.
            CreditUsageLog.objects.create(
                account=self.account,
                operation_type='clustering',
                credits_used=credits_used,
                cost_usd=raw_response.get('cost'),
                model_used=raw_response.get('model', ''),
                tokens_input=raw_response.get('tokens_input', 0),
                tokens_output=raw_response.get('tokens_output', 0),
                related_object_type='cluster',
                metadata={
                    'clusters_created': clusters_created,
                    'keywords_updated': keywords_updated,
                    'function_name': function_name,
                },
            )
        except Exception as e:
            logger.warning(f"Failed to log credit usage: {e}", exc_info=True)

    def execute(self, fn: "BaseAIFunction", payload: dict) -> dict:
        """
        Unified execution pipeline for all AI functions.

        Phases, with the progress percentage reported to ``self.tracker``:

        - INIT (10%): validate the payload.
        - PREP (25%): load data and build the prompt.
        - AI_CALL (50% at start, 70% on response): call the provider.
        - PARSE (85%): parse the provider response.
        - SAVE (98%): persist results via ``fn.save_output``.
        - DONE (100%): finalize and write the task log.

        ``generate_images_from_prompts`` skips AI_CALL and PARSE, because its
        provider calls happen inside ``fn.save_output``.

        Args:
            fn: The AI function to run.
            payload: Request payload. ``payload['ids']`` (optional) gives the
                input item count.

        Returns:
            On success, ``{'success': True, **save_result, 'request_steps': ...,
            'response_steps': ..., 'cost': ..., 'tokens': ...}``.

            On failure, the dict built by ``_handle_error`` (``'success': False``).
            Exceptions raised inside the phases are converted to that dict
            rather than propagated.
        """
        function_name = fn.get_name()
        self.step_tracker.function_name = function_name

        # Fresh console tracker for this run (Stage 3 requirement)
        self.console_tracker = ConsoleStepTracker(function_name)
        self.console_tracker.init(f"Starting {function_name} execution")

        try:
            # Phase 1: INIT - Validation & Setup
            ids = payload.get('ids', [])
            input_count = len(ids) if ids else 0
            input_description = self._get_input_description(function_name, payload, input_count)

            self.console_tracker.prep(f"Validating {input_description}")
            validated = fn.validate(payload, self.account)
            if not validated['valid']:
                validation_error = validated.get('error') or 'Validation failed'
                self.console_tracker.error('ValidationError', validation_error)
                return self._handle_error(validation_error, fn)

            validation_message = self._build_validation_message(function_name, payload, input_count, input_description)
            self.console_tracker.prep("Validation complete")
            self.step_tracker.add_request_step("INIT", "success", validation_message)
            self.tracker.update("INIT", 10, validation_message, meta=self.step_tracker.get_meta())

            # Phase 2: PREP - Data Loading & Prompt Building
            data = fn.prepare(payload, self.account)
            data_count = self._count_prepared_items(data, input_count)
            prep_message = self._get_prep_message(function_name, data_count, data)
            self.console_tracker.prep(prep_message)

            if function_name == 'generate_images_from_prompts':
                # No text prompt: image generation happens inside save_output.
                prompt = "Image generation queue prepared"
            else:
                prompt = fn.build_prompt(data, self.account)
                self.console_tracker.prep(f"Prompt built: {len(prompt)} characters")

            self.step_tracker.add_request_step("PREP", "success", prep_message)
            self.tracker.update("PREP", 25, prep_message, meta=self.step_tracker.get_meta())

            # Phases 3-4: AI_CALL + PARSE
            if function_name == 'generate_images_from_prompts':
                # Skip AI_CALL and PARSE - processing happens in save_output.
                raw_response = {'content': 'Image generation queue ready'}
                parsed = {'processed': True}
            else:
                ai_core = AICore(account=self.account)
                # Tracking ID matches frontend IDs (underscores normalised to hyphens).
                function_id = f"ai-{function_name.replace('_', '-')}-01-desktop"

                # Model config from settings (Stage 4 requirement), account-aware.
                model_config = get_model_config(function_name, account=self.account)
                model = model_config.get('model')
                model_from_integration = self._get_integration_model()

                # Debug logging: model configuration (console only, not in step tracker)
                logger.info(f"[AIEngine] Model Configuration for {function_name}:")
                logger.info(f" - Model from get_model_config: {model}")
                logger.info(f" - Full model_config: {model_config}")
                self.console_tracker.ai_call(f"Model from settings: {model_from_integration or 'Not set'}")
                self.console_tracker.ai_call(f"Model selected for request: {model or 'default'}")
                self.console_tracker.ai_call(f"Calling {model or 'default'} model with {len(prompt)} char prompt")
                self.console_tracker.ai_call(f"Function ID: {function_id}")

                ai_call_message = self._get_ai_call_message(function_name, data_count)
                self.step_tracker.add_response_step("AI_CALL", "success", ai_call_message)
                self.tracker.update("AI_CALL", 50, ai_call_message, meta=self.step_tracker.get_meta())

                try:
                    # Centralized request with console logging (Stage 2 & 3 requirement)
                    raw_response = ai_core.run_ai_request(
                        prompt=prompt,
                        model=model,
                        max_tokens=model_config.get('max_tokens'),
                        temperature=model_config.get('temperature'),
                        response_format=model_config.get('response_format'),
                        function_name=function_name,
                        function_id=function_id,
                        tracker=self.console_tracker,
                    )
                except Exception as e:
                    error_msg = f"AI call failed: {str(e)}"
                    logger.error(f"Exception during AI call: {error_msg}", exc_info=True)
                    return self._handle_error(error_msg, fn)

                if raw_response.get('error'):
                    error_msg = raw_response['error']
                    logger.error(f"AI call returned error: {error_msg}")
                    return self._handle_error(error_msg, fn)

                if not raw_response.get('content'):
                    error_msg = "AI call returned no content"
                    logger.error(error_msg)
                    return self._handle_error(error_msg, fn)

                # Treat missing or None usage figures as 0 so the ":.6f" format cannot fail.
                total_tokens = raw_response.get('total_tokens') or 0
                cost = raw_response.get('cost') or 0
                self.cost_tracker.record(
                    function_name=function_name,
                    cost=cost,
                    tokens=total_tokens,
                    model=raw_response.get('model'),
                )

                # Overwrite the AI_CALL step added above with the actual usage.
                self.step_tracker.response_steps[-1] = {
                    **self.step_tracker.response_steps[-1],
                    'message': f"Received {total_tokens} tokens, Cost: ${cost:.6f}",
                    'duration': raw_response.get('duration'),
                }
                self.tracker.update(
                    "AI_CALL", 70, f"AI response received ({total_tokens} tokens)",
                    meta=self.step_tracker.get_meta(),
                )

                # Phase 4: PARSE - Response Parsing
                # Bound before the try so the except clause can always log it.
                response_content = raw_response.get('content', '')
                try:
                    parse_message = self._get_parse_message(function_name)
                    self.console_tracker.parse(parse_message)
                    parsed = fn.parse_response(response_content, self.step_tracker)

                    parsed_count = self._count_parsed_items(parsed)
                    parse_message = self._get_parse_message_with_count(function_name, parsed_count)

                    self.console_tracker.parse(f"Successfully parsed {parsed_count} items from response")
                    self.step_tracker.add_response_step("PARSE", "success", parse_message)
                    self.tracker.update("PARSE", 85, parse_message, meta=self.step_tracker.get_meta())
                except Exception as parse_error:
                    error_msg = f"Failed to parse AI response: {str(parse_error)}"
                    logger.error(f"AIEngine: {error_msg}", exc_info=True)
                    logger.error(f"AIEngine: Response content was: {response_content[:500] if response_content else 'None'}...")
                    return self._handle_error(error_msg, fn)

            # Phase 5: SAVE - save_output receives step_tracker to add its own steps.
            save_result = fn.save_output(parsed, data, self.account, self.tracker, step_tracker=self.step_tracker)
            clusters_created = save_result.get('clusters_created', 0)
            keywords_updated = save_result.get('keywords_updated', 0)
            saved_count = save_result.get('count', 0)

            if clusters_created:
                save_msg = f"Saving {self._plural(clusters_created, 'cluster')}"
            else:
                # Fall back to the prepared-item count when save_output reports none.
                save_msg = self._get_save_message(function_name, saved_count or data_count)

            self.console_tracker.save(save_msg)
            self.step_tracker.add_request_step("SAVE", "success", save_msg)
            self.tracker.update("SAVE", 98, save_msg, meta=self.step_tracker.get_meta())

            # Track credit usage after a successful save (best-effort).
            if self.account and raw_response:
                self._log_credit_usage(function_name, data, raw_response, clusters_created, keywords_updated)

            # Phase 6: DONE - Finalization
            self.console_tracker.done(f"Task completed: {save_msg}")
            self.step_tracker.add_request_step("DONE", "success", "Task completed successfully")
            self.tracker.update("DONE", 100, "Task complete!", meta=self.step_tracker.get_meta())

            self._log_to_database(fn, payload, parsed, save_result)

            return {
                'success': True,
                **save_result,
                'request_steps': self.step_tracker.request_steps,
                'response_steps': self.step_tracker.response_steps,
                'cost': self.cost_tracker.get_total(),
                'tokens': self.cost_tracker.get_total_tokens(),
            }

        except Exception as e:
            logger.error(f"Error in AIEngine.execute for {function_name}: {str(e)}", exc_info=True)
            return self._handle_error(str(e), fn, exc_info=True)

    def _handle_error(self, error: str, fn: Optional["BaseAIFunction"] = None, exc_info=False):
        """Centralized error handling.

        Records the error on the console tracker (if any), the step tracker and
        the progress tracker. It then logs the error, writes an error
        ``AITaskLog`` row, and returns the failure result dict.

        Args:
            error: Error message. An Exception instance is also tolerated; its
                class name becomes ``error_type``.
            fn: The function that failed; its name is used in log messages.
            exc_info: When truthy, the current exception traceback is logged.

        Returns:
            ``{'success': False, 'error', 'error_type', 'request_steps', 'response_steps'}``.
        """
        function_name = fn.get_name() if fn else 'unknown'
        error_type = type(error).__name__ if isinstance(error, Exception) else 'Error'

        # Log to console tracker if available (Stage 3 requirement)
        if self.console_tracker:
            self.console_tracker.error(
                error_type, str(error),
                exception=error if isinstance(error, Exception) else None,
            )

        self.step_tracker.add_request_step("Error", "error", error, error=error)

        error_meta = {
            'error': error,
            'error_type': error_type,
            **self.step_tracker.get_meta(),
        }
        self.tracker.error(error, meta=error_meta)

        logger.error(f"Error in {function_name}: {error}", exc_info=bool(exc_info))

        self._log_to_database(fn, None, None, None, error=error)

        return {
            'success': False,
            'error': error,
            'error_type': error_type,
            'request_steps': self.step_tracker.request_steps,
            'response_steps': self.step_tracker.response_steps,
        }

    def _log_to_database(
        self,
        fn: Optional["BaseAIFunction"] = None,
        payload: Optional[dict] = None,
        parsed: Any = None,
        save_result: Optional[dict] = None,
        error: Optional[str] = None,
    ):
        """Write one row to the unified ai_task_logs table (``AITaskLog``).

        The row's status is 'error' when ``error`` is given, else 'success'.
        It is skipped when there is no account. Any failure is logged as a
        warning so logging never fails the task. ``parsed`` is accepted for
        interface stability but is not stored.
        """
        try:
            from igny8_core.ai.models import AITaskLog

            # Only log if account exists (AITaskLog requires account)
            if not self.account:
                logger.warning("Cannot log AI task - no account available")
                return

            AITaskLog.objects.create(
                task_id=self.task.request.id if self.task else None,
                function_name=fn.get_name() if fn else None,
                account=self.account,
                phase=self.tracker.current_phase,
                message=self.tracker.current_message,
                status='error' if error else 'success',
                duration=self.tracker.get_duration(),
                cost=self.cost_tracker.get_total(),
                tokens=self.cost_tracker.get_total_tokens(),
                request_steps=self.step_tracker.request_steps,
                response_steps=self.step_tracker.response_steps,
                error=error,
                payload=payload,
                result=save_result,
            )
        except Exception as e:
            # Don't fail the task if logging fails
            logger.warning(f"Failed to log to database: {e}")

    def _calculate_credits_for_clustering(self, keyword_count, tokens, cost):
        """Calculate credits used for a clustering operation.

        Uses the account plan's ``ai_cost_per_request['cluster']`` when it is
        set and non-zero. Otherwise charges 1 credit per full 30 keywords, with
        a minimum of 1. ``tokens`` and ``cost`` are accepted but currently unused.
        """
        plan = getattr(self.account, 'plan', None) if self.account else None
        if plan:
            per_request = getattr(plan, 'ai_cost_per_request', None)
            if per_request:
                cluster_cost = per_request.get('cluster', 0)
                if cluster_cost:
                    return int(cluster_cost)

        # Fallback: 1 credit per 30 keywords (minimum 1)
        return max(1, keyword_count // 30)
|