Files
igny8/backend/igny8_core/ai/engine.py
2025-11-11 19:14:04 +00:00

486 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI Engine - Central orchestrator for all AI functions
"""
import logging
from typing import Dict, Any, Optional
from igny8_core.ai.base import BaseAIFunction
from igny8_core.ai.tracker import StepTracker, ProgressTracker, CostTracker, ConsoleStepTracker
from igny8_core.ai.ai_core import AICore
from igny8_core.ai.settings import get_model_config
logger = logging.getLogger(__name__)
class AIEngine:
    """
    Central orchestrator for all AI functions.
    Manages lifecycle, progress, logging, retries, cost tracking.

    Typical usage: construct with the running Celery task and the tenant
    account, then call execute() with a BaseAIFunction implementation.
    """
    def __init__(self, celery_task=None, account=None):
        """Set up the per-run trackers.

        Args:
            celery_task: bound Celery task instance (or None when run
                synchronously); forwarded to ProgressTracker so phase
                updates surface as task state.
            account: account/tenant the AI work is scoped and billed to.
        """
        self.task = celery_task
        self.account = account
        self.tracker = ProgressTracker(celery_task)
        self.step_tracker = StepTracker('ai_engine')  # For Celery progress callbacks
        self.console_tracker = None  # Will be initialized per function (inside execute())
        self.cost_tracker = CostTracker()
def _get_input_description(self, function_name: str, payload: dict, count: int) -> str:
"""Get user-friendly input description"""
if function_name == 'auto_cluster':
return f"{count} keyword{'s' if count != 1 else ''}"
elif function_name == 'generate_ideas':
return f"{count} cluster{'s' if count != 1 else ''}"
elif function_name == 'generate_content':
return f"{count} task{'s' if count != 1 else ''}"
elif function_name == 'generate_images':
return f"{count} task{'s' if count != 1 else ''}"
return f"{count} item{'s' if count != 1 else ''}"
def _build_validation_message(self, function_name: str, payload: dict, count: int, input_description: str) -> str:
"""Build validation message with item names for better UX"""
if function_name == 'auto_cluster' and count > 0:
try:
from igny8_core.modules.planner.models import Keywords
ids = payload.get('ids', [])
keywords = Keywords.objects.filter(id__in=ids, account=self.account).values_list('keyword', flat=True)[:3]
keyword_list = list(keywords)
if len(keyword_list) > 0:
remaining = count - len(keyword_list)
if remaining > 0:
keywords_text = ', '.join(keyword_list)
return f"Validating {keywords_text} and {remaining} more keyword{'s' if remaining != 1 else ''}"
else:
keywords_text = ', '.join(keyword_list)
return f"Validating {keywords_text}"
except Exception as e:
logger.warning(f"Failed to load keyword names for validation message: {e}")
# Fallback to simple count message
return f"Validating {input_description}"
def _get_prep_message(self, function_name: str, count: int, data: Any) -> str:
"""Get user-friendly prep message"""
if function_name == 'auto_cluster':
return f"Loading {count} keyword{'s' if count != 1 else ''}"
elif function_name == 'generate_ideas':
return f"Loading {count} cluster{'s' if count != 1 else ''}"
elif function_name == 'generate_content':
return f"Preparing {count} content idea{'s' if count != 1 else ''}"
elif function_name == 'generate_images':
return f"Extracting image prompts from {count} task{'s' if count != 1 else ''}"
elif function_name == 'generate_image_prompts':
# Extract max_images from data if available
if isinstance(data, list) and len(data) > 0:
max_images = data[0].get('max_images', 2)
total_images = 1 + max_images # 1 featured + max_images in-article
return f"Mapping Content for {total_images} Image Prompts"
elif isinstance(data, dict) and 'max_images' in data:
max_images = data.get('max_images', 2)
total_images = 1 + max_images
return f"Mapping Content for {total_images} Image Prompts"
return f"Mapping Content for Image Prompts"
return f"Preparing {count} item{'s' if count != 1 else ''}"
def _get_ai_call_message(self, function_name: str, count: int) -> str:
"""Get user-friendly AI call message"""
if function_name == 'auto_cluster':
return f"Grouping {count} keyword{'s' if count != 1 else ''} into clusters"
elif function_name == 'generate_ideas':
return f"Generating content ideas for {count} cluster{'s' if count != 1 else ''}"
elif function_name == 'generate_content':
return f"Writing article{'s' if count != 1 else ''} with AI"
elif function_name == 'generate_images':
return f"Creating image{'s' if count != 1 else ''} with AI"
return f"Processing with AI"
def _get_parse_message(self, function_name: str) -> str:
"""Get user-friendly parse message"""
if function_name == 'auto_cluster':
return "Organizing clusters"
elif function_name == 'generate_ideas':
return "Structuring outlines"
elif function_name == 'generate_content':
return "Formatting content"
elif function_name == 'generate_images':
return "Processing images"
return "Processing results"
def _get_parse_message_with_count(self, function_name: str, count: int) -> str:
"""Get user-friendly parse message with count"""
if function_name == 'auto_cluster':
return f"{count} cluster{'s' if count != 1 else ''} created"
elif function_name == 'generate_ideas':
return f"{count} idea{'s' if count != 1 else ''} created"
elif function_name == 'generate_content':
return f"{count} article{'s' if count != 1 else ''} created"
elif function_name == 'generate_images':
return f"{count} image{'s' if count != 1 else ''} created"
elif function_name == 'generate_image_prompts':
# Count is total prompts, in-article is count - 1 (subtract featured)
in_article_count = max(0, count - 1)
if in_article_count > 0:
return f"Writing {in_article_count} Inarticle Image Prompts"
return "Writing Inarticle Image Prompts"
return f"{count} item{'s' if count != 1 else ''} processed"
def _get_save_message(self, function_name: str, count: int) -> str:
"""Get user-friendly save message"""
if function_name == 'auto_cluster':
return f"Saving {count} cluster{'s' if count != 1 else ''}"
elif function_name == 'generate_ideas':
return f"Saving {count} idea{'s' if count != 1 else ''}"
elif function_name == 'generate_content':
return f"Saving {count} article{'s' if count != 1 else ''}"
elif function_name == 'generate_images':
return f"Saving {count} image{'s' if count != 1 else ''}"
elif function_name == 'generate_image_prompts':
# Count is total prompts created
return f"Assigning {count} Prompts to Dedicated Slots"
return f"Saving {count} item{'s' if count != 1 else ''}"
    def execute(self, fn: BaseAIFunction, payload: dict) -> dict:
        """
        Unified execution pipeline for all AI functions.

        Drives the function through six sequential phases, reporting
        progress to the console tracker, the step tracker, and the
        Celery-backed ProgressTracker at each stage.

        Phases with improved percentage mapping:
        - INIT (0-10%): Validation & preparation
        - PREP (10-25%): Data loading & prompt building
        - AI_CALL (25-70%): API call to provider (longest phase)
        - PARSE (70-85%): Response parsing
        - SAVE (85-98%): Database operations
        - DONE (98-100%): Finalization

        Args:
            fn: AI function object providing the validate/prepare/
                build_prompt/parse_response/save_output hooks.
            payload: raw request payload; an 'ids' list, when present,
                is used only for progress messaging.

        Returns:
            On success: dict with 'success': True merged with the
            save_output result plus step lists, cost, and token totals.
            On failure: the error payload produced by _handle_error().
        """
        function_name = fn.get_name()
        self.step_tracker.function_name = function_name
        # Initialize console tracker for logging (Stage 3 requirement)
        self.console_tracker = ConsoleStepTracker(function_name)
        self.console_tracker.init(f"Starting {function_name} execution")
        try:
            # Phase 1: INIT - Validation & Setup (0-10%)
            # Extract input data for user-friendly messages
            ids = payload.get('ids', [])
            input_count = len(ids) if ids else 0
            input_description = self._get_input_description(function_name, payload, input_count)
            self.console_tracker.prep(f"Validating {input_description}")
            validated = fn.validate(payload, self.account)
            if not validated['valid']:
                self.console_tracker.error('ValidationError', validated['error'])
                return self._handle_error(validated['error'], fn)
            # Build validation message with keyword names for auto_cluster
            validation_message = self._build_validation_message(function_name, payload, input_count, input_description)
            self.console_tracker.prep("Validation complete")
            self.step_tracker.add_request_step("INIT", "success", validation_message)
            self.tracker.update("INIT", 10, validation_message, meta=self.step_tracker.get_meta())
            # Phase 2: PREP - Data Loading & Prompt Building (10-25%)
            data = fn.prepare(payload, self.account)
            # Derive an item count for messaging from whatever shape
            # prepare() returned (sequence, known dict layouts, or other).
            if isinstance(data, (list, tuple)):
                data_count = len(data)
            elif isinstance(data, dict):
                # Check for cluster_data (for generate_ideas) or keywords (for auto_cluster)
                if 'cluster_data' in data:
                    data_count = len(data['cluster_data'])
                elif 'keywords' in data:
                    data_count = len(data['keywords'])
                else:
                    data_count = data.get('count', input_count)
            else:
                data_count = input_count
            prep_message = self._get_prep_message(function_name, data_count, data)
            self.console_tracker.prep(prep_message)
            prompt = fn.build_prompt(data, self.account)
            self.console_tracker.prep(f"Prompt built: {len(prompt)} characters")
            self.step_tracker.add_request_step("PREP", "success", prep_message)
            self.tracker.update("PREP", 25, prep_message, meta=self.step_tracker.get_meta())
            # Phase 3: AI_CALL - Provider API Call (25-70%)
            ai_core = AICore(account=self.account)
            function_name = fn.get_name()  # NOTE(review): redundant — already assigned above
            # Generate function_id for tracking (ai-{function_name}-01)
            # Normalize underscores to hyphens to match frontend tracking IDs
            function_id_base = function_name.replace('_', '-')
            function_id = f"ai-{function_id_base}-01-desktop"
            # Get model config from settings (Stage 4 requirement)
            # Pass account to read model from IntegrationSettings
            model_config = get_model_config(function_name, account=self.account)
            model = model_config.get('model')
            # Read model straight from IntegrationSettings for visibility.
            # This value is only logged below; the request itself uses
            # `model` from model_config.
            model_from_integration = None
            if self.account:
                try:
                    from igny8_core.modules.system.models import IntegrationSettings
                    openai_settings = IntegrationSettings.objects.filter(
                        integration_type='openai',
                        account=self.account,
                        is_active=True
                    ).first()
                    if openai_settings and openai_settings.config:
                        model_from_integration = openai_settings.config.get('model')
                except Exception as integration_error:
                    # Best-effort lookup: never fail the task over a logging aid.
                    logger.warning(
                        "[AIEngine] Unable to read model from IntegrationSettings: %s",
                        integration_error,
                        exc_info=True,
                    )
            # Debug logging: Show model configuration (console only, not in step tracker)
            logger.info(f"[AIEngine] Model Configuration for {function_name}:")
            logger.info(f" - Model from get_model_config: {model}")
            logger.info(f" - Full model_config: {model_config}")
            self.console_tracker.ai_call(f"Model from settings: {model_from_integration or 'Not set'}")
            self.console_tracker.ai_call(f"Model selected for request: {model or 'default'}")
            self.console_tracker.ai_call(f"Calling {model or 'default'} model with {len(prompt)} char prompt")
            self.console_tracker.ai_call(f"Function ID: {function_id}")
            # Track AI call start with user-friendly message
            ai_call_message = self._get_ai_call_message(function_name, data_count)
            self.step_tracker.add_response_step("AI_CALL", "success", ai_call_message)
            self.tracker.update("AI_CALL", 50, ai_call_message, meta=self.step_tracker.get_meta())
            try:
                # Use centralized run_ai_request() with console logging (Stage 2 & 3 requirement)
                # Pass console_tracker for unified logging
                raw_response = ai_core.run_ai_request(
                    prompt=prompt,
                    model=model,
                    max_tokens=model_config.get('max_tokens'),
                    temperature=model_config.get('temperature'),
                    response_format=model_config.get('response_format'),
                    function_name=function_name,
                    function_id=function_id,  # Pass function_id for tracking
                    tracker=self.console_tracker  # Pass console tracker for logging
                )
            except Exception as e:
                error_msg = f"AI call failed: {str(e)}"
                logger.error(f"Exception during AI call: {error_msg}", exc_info=True)
                return self._handle_error(error_msg, fn)
            # Failures may also surface in the response dict itself —
            # either an 'error' key or missing 'content'.
            if raw_response.get('error'):
                error_msg = raw_response.get('error', 'Unknown AI error')
                logger.error(f"AI call returned error: {error_msg}")
                return self._handle_error(error_msg, fn)
            if not raw_response.get('content'):
                error_msg = "AI call returned no content"
                logger.error(error_msg)
                return self._handle_error(error_msg, fn)
            # Track cost
            self.cost_tracker.record(
                function_name=function_name,
                cost=raw_response.get('cost', 0),
                tokens=raw_response.get('total_tokens', 0),
                model=raw_response.get('model')
            )
            # Update AI_CALL step with results: rewrite the last response
            # step in place so token/cost details replace the start message.
            self.step_tracker.response_steps[-1] = {
                **self.step_tracker.response_steps[-1],
                'message': f"Received {raw_response.get('total_tokens', 0)} tokens, Cost: ${raw_response.get('cost', 0):.6f}",
                'duration': raw_response.get('duration')
            }
            self.tracker.update("AI_CALL", 70, f"AI response received ({raw_response.get('total_tokens', 0)} tokens)", meta=self.step_tracker.get_meta())
            # Phase 4: PARSE - Response Parsing (70-85%)
            try:
                parse_message = self._get_parse_message(function_name)
                self.console_tracker.parse(parse_message)
                response_content = raw_response.get('content', '')
                parsed = fn.parse_response(response_content, self.step_tracker)
                if isinstance(parsed, (list, tuple)):
                    parsed_count = len(parsed)
                elif isinstance(parsed, dict):
                    # Check if it's a content dict (has 'content' field) or a result dict (has 'count')
                    if 'content' in parsed:
                        parsed_count = 1  # Single content item
                    else:
                        parsed_count = parsed.get('count', 1)
                else:
                    parsed_count = 1
                # Update parse message with count for better UX
                parse_message = self._get_parse_message_with_count(function_name, parsed_count)
                self.console_tracker.parse(f"Successfully parsed {parsed_count} items from response")
                self.step_tracker.add_response_step("PARSE", "success", parse_message)
                self.tracker.update("PARSE", 85, parse_message, meta=self.step_tracker.get_meta())
            except Exception as parse_error:
                error_msg = f"Failed to parse AI response: {str(parse_error)}"
                logger.error(f"AIEngine: {error_msg}", exc_info=True)
                logger.error(f"AIEngine: Response content was: {response_content[:500] if response_content else 'None'}...")
                return self._handle_error(error_msg, fn)
            # Phase 5: SAVE - Database Operations (85-98%)
            # Pass step_tracker to save_output so it can add validation steps
            save_result = fn.save_output(parsed, data, self.account, self.tracker, step_tracker=self.step_tracker)
            clusters_created = save_result.get('clusters_created', 0)
            keywords_updated = save_result.get('keywords_updated', 0)
            count = save_result.get('count', 0)
            # Use user-friendly save message based on function type
            if clusters_created:
                save_msg = f"Saving {clusters_created} cluster{'s' if clusters_created != 1 else ''}"
            elif count:
                save_msg = self._get_save_message(function_name, count)
            else:
                save_msg = self._get_save_message(function_name, data_count)
            self.console_tracker.save(save_msg)
            self.step_tracker.add_request_step("SAVE", "success", save_msg)
            self.tracker.update("SAVE", 98, save_msg, meta=self.step_tracker.get_meta())
            # Store save_msg for use in DONE phase
            final_save_msg = save_msg
            # Track credit usage after successful save
            if self.account and raw_response:
                try:
                    # NOTE(review): CreditService import appears unused — confirm and remove.
                    from igny8_core.modules.billing.services import CreditService
                    from igny8_core.modules.billing.models import CreditUsageLog
                    # Calculate credits used (based on tokens or fixed cost)
                    credits_used = self._calculate_credits_for_clustering(
                        keyword_count=len(data.get('keywords', [])) if isinstance(data, dict) else len(data) if isinstance(data, list) else 1,
                        tokens=raw_response.get('total_tokens', 0),
                        cost=raw_response.get('cost', 0)
                    )
                    # Log credit usage (don't deduct from account.credits, just log)
                    # NOTE(review): operation_type/related_object_type are hard-coded
                    # to 'clustering'/'cluster' even for non-clustering functions —
                    # confirm this is intended for billing analytics.
                    CreditUsageLog.objects.create(
                        account=self.account,
                        operation_type='clustering',
                        credits_used=credits_used,
                        cost_usd=raw_response.get('cost'),
                        model_used=raw_response.get('model', ''),
                        tokens_input=raw_response.get('tokens_input', 0),
                        tokens_output=raw_response.get('tokens_output', 0),
                        related_object_type='cluster',
                        metadata={
                            'clusters_created': clusters_created,
                            'keywords_updated': keywords_updated,
                            'function_name': function_name
                        }
                    )
                except Exception as e:
                    # Billing log is best-effort; never fail the task for it.
                    logger.warning(f"Failed to log credit usage: {e}", exc_info=True)
            # Phase 6: DONE - Finalization (98-100%)
            # NOTE(review): final_save_msg is always bound by this point,
            # so the locals() guard is purely defensive.
            success_msg = f"Task completed: {final_save_msg}" if 'final_save_msg' in locals() else "Task completed successfully"
            self.console_tracker.done(success_msg)
            self.step_tracker.add_request_step("DONE", "success", "Task completed successfully")
            self.tracker.update("DONE", 100, "Task complete!", meta=self.step_tracker.get_meta())
            # Log to database
            self._log_to_database(fn, payload, parsed, save_result)
            return {
                'success': True,
                **save_result,
                'request_steps': self.step_tracker.request_steps,
                'response_steps': self.step_tracker.response_steps,
                'cost': self.cost_tracker.get_total(),
                'tokens': self.cost_tracker.get_total_tokens()
            }
        except Exception as e:
            # Catch-all boundary: any unexpected failure in any phase is
            # logged and converted into the standard error payload.
            logger.error(f"Error in AIEngine.execute for {function_name}: {str(e)}", exc_info=True)
            return self._handle_error(str(e), fn, exc_info=True)
def _handle_error(self, error: str, fn: BaseAIFunction = None, exc_info=False):
"""Centralized error handling"""
function_name = fn.get_name() if fn else 'unknown'
# Log to console tracker if available (Stage 3 requirement)
if self.console_tracker:
error_type = type(error).__name__ if isinstance(error, Exception) else 'Error'
self.console_tracker.error(error_type, str(error), exception=error if isinstance(error, Exception) else None)
self.step_tracker.add_request_step("Error", "error", error, error=error)
error_meta = {
'error': error,
'error_type': type(error).__name__ if isinstance(error, Exception) else 'Error',
**self.step_tracker.get_meta()
}
self.tracker.error(error, meta=error_meta)
if exc_info:
logger.error(f"Error in {function_name}: {error}", exc_info=True)
else:
logger.error(f"Error in {function_name}: {error}")
self._log_to_database(fn, None, None, None, error=error)
return {
'success': False,
'error': error,
'error_type': type(error).__name__ if isinstance(error, Exception) else 'Error',
'request_steps': self.step_tracker.request_steps,
'response_steps': self.step_tracker.response_steps
}
def _log_to_database(
self,
fn: BaseAIFunction = None,
payload: dict = None,
parsed: Any = None,
save_result: dict = None,
error: str = None
):
"""Log to unified ai_task_logs table"""
try:
from igny8_core.ai.models import AITaskLog
# Only log if account exists (AITaskLog requires account)
if not self.account:
logger.warning("Cannot log AI task - no account available")
return
AITaskLog.objects.create(
task_id=self.task.request.id if self.task else None,
function_name=fn.get_name() if fn else None,
account=self.account,
phase=self.tracker.current_phase,
message=self.tracker.current_message,
status='error' if error else 'success',
duration=self.tracker.get_duration(),
cost=self.cost_tracker.get_total(),
tokens=self.cost_tracker.get_total_tokens(),
request_steps=self.step_tracker.request_steps,
response_steps=self.step_tracker.response_steps,
error=error,
payload=payload,
result=save_result
)
except Exception as e:
# Don't fail the task if logging fails
logger.warning(f"Failed to log to database: {e}")
def _calculate_credits_for_clustering(self, keyword_count, tokens, cost):
"""Calculate credits used for clustering operation"""
# Use plan's cost per request if available, otherwise calculate from tokens
if self.account and hasattr(self.account, 'plan') and self.account.plan:
plan = self.account.plan
# Check if plan has ai_cost_per_request config
if hasattr(plan, 'ai_cost_per_request') and plan.ai_cost_per_request:
cluster_cost = plan.ai_cost_per_request.get('cluster', 0)
if cluster_cost:
return int(cluster_cost)
# Fallback: 1 credit per 30 keywords (minimum 1)
credits = max(1, int(keyword_count / 30))
return credits