""" AI Engine - Central orchestrator for all AI functions """ import logging from typing import Dict, Any, Optional from igny8_core.ai.base import BaseAIFunction from igny8_core.ai.tracker import StepTracker, ProgressTracker, CostTracker from igny8_core.ai.ai_core import AICore from igny8_core.ai.settings import get_model_config logger = logging.getLogger(__name__) class AIEngine: """ Central orchestrator for all AI functions. Manages lifecycle, progress, logging, retries, cost tracking. """ def __init__(self, celery_task=None, account=None): self.task = celery_task self.account = account self.tracker = ProgressTracker(celery_task) self.step_tracker = StepTracker('ai_engine') # For Celery progress callbacks self.cost_tracker = CostTracker() def _get_input_description(self, function_name: str, payload: dict, count: int) -> str: """Get user-friendly input description""" if function_name == 'auto_cluster': return f"{count} keyword{'s' if count != 1 else ''}" elif function_name == 'generate_ideas': return f"{count} cluster{'s' if count != 1 else ''}" elif function_name == 'generate_content': return f"{count} task{'s' if count != 1 else ''}" elif function_name == 'generate_images': return f"{count} task{'s' if count != 1 else ''}" return f"{count} item{'s' if count != 1 else ''}" def _build_validation_message(self, function_name: str, payload: dict, count: int, input_description: str) -> str: """Build validation message with item names for better UX""" if function_name == 'auto_cluster' and count > 0: try: from igny8_core.modules.planner.models import Keywords ids = payload.get('ids', []) keywords = Keywords.objects.filter(id__in=ids, account=self.account).values_list('keyword', flat=True)[:3] keyword_list = list(keywords) if len(keyword_list) > 0: remaining = count - len(keyword_list) if remaining > 0: keywords_text = ', '.join(keyword_list) return f"Validating {keywords_text} and {remaining} more keyword{'s' if remaining != 1 else ''}" else: keywords_text = ', '.join(keyword_list) return f"Validating {keywords_text}" except Exception as e: logger.warning(f"Failed to load keyword names for validation message: {e}") # Fallback to simple count message return f"Validating {input_description}" def _get_prep_message(self, function_name: str, count: int, data: Any) -> str: """Get user-friendly prep message""" if function_name == 'auto_cluster': return f"Loading {count} keyword{'s' if count != 1 else ''}" elif function_name == 'generate_ideas': return f"Loading {count} cluster{'s' if count != 1 else ''}" elif function_name == 'generate_content': return f"Preparing {count} content idea{'s' if count != 1 else ''}" elif function_name == 'generate_images': return f"Extracting image prompts from {count} task{'s' if count != 1 else ''}" elif function_name == 'generate_image_prompts': # Extract max_images from data if available if isinstance(data, list) and len(data) > 0: max_images = data[0].get('max_images', 2) total_images = 1 + max_images # 1 featured + max_images in-article return f"Mapping Content for {total_images} Image Prompts" elif isinstance(data, dict) and 'max_images' in data: max_images = data.get('max_images', 2) total_images = 1 + max_images return f"Mapping Content for {total_images} Image Prompts" return f"Mapping Content for Image Prompts" return f"Preparing {count} item{'s' if count != 1 else ''}" def _get_ai_call_message(self, function_name: str, count: int) -> str: """Get user-friendly AI call message""" if function_name == 'auto_cluster': return f"Grouping {count} keyword{'s' if count != 1 else ''} into clusters" elif function_name == 'generate_ideas': return f"Generating content ideas for {count} cluster{'s' if count != 1 else ''}" elif function_name == 'generate_content': return f"Writing article{'s' if count != 1 else ''} with AI" elif function_name == 'generate_images': return f"Creating image{'s' if count != 1 else ''} with AI" return f"Processing with AI" def _get_parse_message(self, function_name: str) -> str: """Get user-friendly parse message""" if function_name == 'auto_cluster': return "Organizing clusters" elif function_name == 'generate_ideas': return "Structuring outlines" elif function_name == 'generate_content': return "Formatting content" elif function_name == 'generate_images': return "Processing images" return "Processing results" def _get_parse_message_with_count(self, function_name: str, count: int) -> str: """Get user-friendly parse message with count""" if function_name == 'auto_cluster': return f"{count} cluster{'s' if count != 1 else ''} created" elif function_name == 'generate_ideas': return f"{count} idea{'s' if count != 1 else ''} created" elif function_name == 'generate_content': return f"{count} article{'s' if count != 1 else ''} created" elif function_name == 'generate_images': return f"{count} image{'s' if count != 1 else ''} created" elif function_name == 'generate_image_prompts': # Count is total prompts, in-article is count - 1 (subtract featured) in_article_count = max(0, count - 1) if in_article_count > 0: return f"Writing {in_article_count} In‑article Image Prompts" return "Writing In‑article Image Prompts" return f"{count} item{'s' if count != 1 else ''} processed" def _get_save_message(self, function_name: str, count: int) -> str: """Get user-friendly save message""" if function_name == 'auto_cluster': return f"Saving {count} cluster{'s' if count != 1 else ''}" elif function_name == 'generate_ideas': return f"Saving {count} idea{'s' if count != 1 else ''}" elif function_name == 'generate_content': return f"Saving {count} article{'s' if count != 1 else ''}" elif function_name == 'generate_images': return f"Saving {count} image{'s' if count != 1 else ''}" elif function_name == 'generate_image_prompts': # Count is total prompts created return f"Assigning {count} Prompts to Dedicated Slots" return f"Saving {count} item{'s' if count != 1 else ''}" def execute(self, fn: BaseAIFunction, payload: dict) -> dict: """ Unified execution pipeline for all AI functions. Phases with improved percentage mapping: - INIT (0-10%): Validation & preparation - PREP (10-25%): Data loading & prompt building - AI_CALL (25-70%): API call to provider (longest phase) - PARSE (70-85%): Response parsing - SAVE (85-98%): Database operations - DONE (98-100%): Finalization """ function_name = fn.get_name() self.step_tracker.function_name = function_name try: # Phase 1: INIT - Validation & Setup (0-10%) # Extract input data for user-friendly messages ids = payload.get('ids', []) input_count = len(ids) if ids else 0 input_description = self._get_input_description(function_name, payload, input_count) validated = fn.validate(payload, self.account) if not validated['valid']: return self._handle_error(validated['error'], fn) # Build validation message with keyword names for auto_cluster validation_message = self._build_validation_message(function_name, payload, input_count, input_description) self.step_tracker.add_request_step("INIT", "success", validation_message) self.tracker.update("INIT", 10, validation_message, meta=self.step_tracker.get_meta()) # Phase 2: PREP - Data Loading & Prompt Building (10-25%) data = fn.prepare(payload, self.account) if isinstance(data, (list, tuple)): data_count = len(data) elif isinstance(data, dict): # Check for cluster_data (for generate_ideas) or keywords (for auto_cluster) if 'cluster_data' in data: data_count = len(data['cluster_data']) elif 'keywords' in data: data_count = len(data['keywords']) else: data_count = data.get('count', input_count) else: data_count = input_count prep_message = self._get_prep_message(function_name, data_count, data) prompt = fn.build_prompt(data, self.account) self.step_tracker.add_request_step("PREP", "success", prep_message) self.tracker.update("PREP", 25, prep_message, meta=self.step_tracker.get_meta()) # Phase 2.5: CREDIT CHECK - Check credits before AI call (25%) # Bypass for system accounts and developers (handled in CreditService) if self.account: try: from igny8_core.modules.billing.services import CreditService from igny8_core.modules.billing.exceptions import InsufficientCreditsError # Map function name to operation type operation_type = self._get_operation_type(function_name) # Calculate estimated cost estimated_amount = self._get_estimated_amount(function_name, data, payload) # Check credits BEFORE AI call (CreditService handles developer/system account bypass) # Note: user=None for Celery tasks, but CreditService checks account.is_system_account() and developer users CreditService.check_credits(self.account, operation_type, estimated_amount, user=None) logger.info(f"[AIEngine] Credit check passed: {operation_type}, estimated amount: {estimated_amount}") except InsufficientCreditsError as e: error_msg = str(e) error_type = 'InsufficientCreditsError' logger.error(f"[AIEngine] {error_msg}") return self._handle_error(error_msg, fn, error_type=error_type) except Exception as e: logger.warning(f"[AIEngine] Failed to check credits: {e}", exc_info=True) # Don't fail the operation if credit check fails (for backward compatibility) # Phase 3: AI_CALL - Provider API Call (25-70%) # Validate account exists before proceeding if not self.account: error_msg = "Account is required for AI function execution" logger.error(f"[AIEngine] {error_msg}") return self._handle_error(error_msg, fn) ai_core = AICore(account=self.account) function_name = fn.get_name() # Generate function_id for tracking (ai-{function_name}-01) # Normalize underscores to hyphens to match frontend tracking IDs function_id_base = function_name.replace('_', '-') function_id = f"ai-{function_id_base}-01-desktop" # Get model config from settings (requires account) # This will raise ValueError if IntegrationSettings not configured try: model_config = get_model_config(function_name, account=self.account) model = model_config.get('model') except ValueError as e: # IntegrationSettings not configured or model missing error_msg = str(e) error_type = 'ConfigurationError' logger.error(f"[AIEngine] {error_msg}") return self._handle_error(error_msg, fn, error_type=error_type) except Exception as e: # Other unexpected errors error_msg = f"Failed to get model configuration: {str(e)}" error_type = type(e).__name__ logger.error(f"[AIEngine] {error_msg}", exc_info=True) return self._handle_error(error_msg, fn, error_type=error_type) # Debug logging: Show model configuration (console only, not in step tracker) logger.info(f"[AIEngine] Model Configuration for {function_name}:") logger.info(f" - Model from get_model_config: {model}") logger.info(f" - Full model_config: {model_config}") # Track AI call start with user-friendly message ai_call_message = self._get_ai_call_message(function_name, data_count) self.step_tracker.add_response_step("AI_CALL", "success", ai_call_message) self.tracker.update("AI_CALL", 50, ai_call_message, meta=self.step_tracker.get_meta()) try: # Use centralized run_ai_request() raw_response = ai_core.run_ai_request( prompt=prompt, model=model, max_tokens=model_config.get('max_tokens'), temperature=model_config.get('temperature'), response_format=model_config.get('response_format'), function_name=function_name, function_id=function_id # Pass function_id for tracking ) except Exception as e: error_msg = f"AI call failed: {str(e)}" logger.error(f"Exception during AI call: {error_msg}", exc_info=True) return self._handle_error(error_msg, fn) if raw_response.get('error'): error_msg = raw_response.get('error', 'Unknown AI error') logger.error(f"AI call returned error: {error_msg}") return self._handle_error(error_msg, fn) if not raw_response.get('content'): error_msg = "AI call returned no content" logger.error(error_msg) return self._handle_error(error_msg, fn) # Track cost self.cost_tracker.record( function_name=function_name, cost=raw_response.get('cost', 0), tokens=raw_response.get('total_tokens', 0), model=raw_response.get('model') ) # Update AI_CALL step with results self.step_tracker.response_steps[-1] = { **self.step_tracker.response_steps[-1], 'message': f"Received {raw_response.get('total_tokens', 0)} tokens, Cost: ${raw_response.get('cost', 0):.6f}", 'duration': raw_response.get('duration') } self.tracker.update("AI_CALL", 70, f"AI response received ({raw_response.get('total_tokens', 0)} tokens)", meta=self.step_tracker.get_meta()) # Phase 4: PARSE - Response Parsing (70-85%) try: parse_message = self._get_parse_message(function_name) response_content = raw_response.get('content', '') parsed = fn.parse_response(response_content, self.step_tracker) if isinstance(parsed, (list, tuple)): parsed_count = len(parsed) elif isinstance(parsed, dict): # Check if it's a content dict (has 'content' field) or a result dict (has 'count') if 'content' in parsed: parsed_count = 1 # Single content item else: parsed_count = parsed.get('count', 1) else: parsed_count = 1 # Update parse message with count for better UX parse_message = self._get_parse_message_with_count(function_name, parsed_count) self.step_tracker.add_response_step("PARSE", "success", parse_message) self.tracker.update("PARSE", 85, parse_message, meta=self.step_tracker.get_meta()) except Exception as parse_error: error_msg = f"Failed to parse AI response: {str(parse_error)}" logger.error(f"AIEngine: {error_msg}", exc_info=True) logger.error(f"AIEngine: Response content was: {response_content[:500] if response_content else 'None'}...") return self._handle_error(error_msg, fn) # Phase 5: SAVE - Database Operations (85-98%) save_result = fn.save_output(parsed, data, self.account, self.tracker, step_tracker=self.step_tracker) clusters_created = save_result.get('clusters_created', 0) keywords_updated = save_result.get('keywords_updated', 0) count = save_result.get('count', 0) # Use user-friendly save message based on function type if clusters_created: save_msg = f"Saving {clusters_created} cluster{'s' if clusters_created != 1 else ''}" elif count: save_msg = self._get_save_message(function_name, count) else: save_msg = self._get_save_message(function_name, data_count) self.step_tracker.add_request_step("SAVE", "success", save_msg) self.tracker.update("SAVE", 98, save_msg, meta=self.step_tracker.get_meta()) # Store save_msg for use in DONE phase final_save_msg = save_msg # Phase 5.5: DEDUCT CREDITS - Deduct credits after successful save if self.account and raw_response: try: from igny8_core.modules.billing.services import CreditService from igny8_core.modules.billing.exceptions import InsufficientCreditsError # Map function name to operation type operation_type = self._get_operation_type(function_name) # Calculate actual amount based on results actual_amount = self._get_actual_amount(function_name, save_result, parsed, data) # Deduct credits using the new convenience method CreditService.deduct_credits_for_operation( account=self.account, operation_type=operation_type, amount=actual_amount, cost_usd=raw_response.get('cost'), model_used=raw_response.get('model', ''), tokens_input=raw_response.get('tokens_input', 0), tokens_output=raw_response.get('tokens_output', 0), related_object_type=self._get_related_object_type(function_name), related_object_id=save_result.get('id') or save_result.get('cluster_id') or save_result.get('task_id'), metadata={ 'function_name': function_name, 'clusters_created': clusters_created, 'keywords_updated': keywords_updated, 'count': count, **save_result } ) logger.info(f"[AIEngine] Credits deducted: {operation_type}, amount: {actual_amount}") except InsufficientCreditsError as e: # This shouldn't happen since we checked before, but log it logger.error(f"[AIEngine] Insufficient credits during deduction: {e}") except Exception as e: logger.warning(f"[AIEngine] Failed to deduct credits: {e}", exc_info=True) # Don't fail the operation if credit deduction fails (for backward compatibility) # Phase 6: DONE - Finalization (98-100%) success_msg = f"Task completed: {final_save_msg}" if 'final_save_msg' in locals() else "Task completed successfully" self.step_tracker.add_request_step("DONE", "success", "Task completed successfully") self.tracker.update("DONE", 100, "Task complete!", meta=self.step_tracker.get_meta()) # Log to database self._log_to_database(fn, payload, parsed, save_result) return { 'success': True, **save_result, 'request_steps': self.step_tracker.request_steps, 'response_steps': self.step_tracker.response_steps, 'cost': self.cost_tracker.get_total(), 'tokens': self.cost_tracker.get_total_tokens() } except Exception as e: error_msg = str(e) error_type = type(e).__name__ logger.error(f"Error in AIEngine.execute for {function_name}: {error_msg}", exc_info=True) return self._handle_error(error_msg, fn, exc_info=True, error_type=error_type) def _handle_error(self, error: str, fn: BaseAIFunction = None, exc_info=False, error_type: str = None): """Centralized error handling""" function_name = fn.get_name() if fn else 'unknown' # Determine error type if error_type: final_error_type = error_type elif isinstance(error, Exception): final_error_type = type(error).__name__ else: final_error_type = 'Error' self.step_tracker.add_request_step("Error", "error", error, error=error) error_meta = { 'error': error, 'error_type': final_error_type, **self.step_tracker.get_meta() } self.tracker.error(error, meta=error_meta) if exc_info: logger.error(f"Error in {function_name}: {error}", exc_info=True) else: logger.error(f"Error in {function_name}: {error}") self._log_to_database(fn, None, None, None, error=error) return { 'success': False, 'error': error, 'error_type': final_error_type, 'request_steps': self.step_tracker.request_steps, 'response_steps': self.step_tracker.response_steps } def _log_to_database( self, fn: BaseAIFunction = None, payload: dict = None, parsed: Any = None, save_result: dict = None, error: str = None ): """Log to unified ai_task_logs table""" try: from igny8_core.ai.models import AITaskLog # Only log if account exists (AITaskLog requires account) if not self.account: logger.warning("Cannot log AI task - no account available") return AITaskLog.objects.create( task_id=self.task.request.id if self.task else None, function_name=fn.get_name() if fn else None, account=self.account, phase=self.tracker.current_phase, message=self.tracker.current_message, status='error' if error else 'success', duration=self.tracker.get_duration(), cost=self.cost_tracker.get_total(), tokens=self.cost_tracker.get_total_tokens(), request_steps=self.step_tracker.request_steps, response_steps=self.step_tracker.response_steps, error=error, payload=payload, result=save_result ) except Exception as e: # Don't fail the task if logging fails logger.warning(f"Failed to log to database: {e}") def _get_operation_type(self, function_name): """Map function name to operation type for credit system""" mapping = { 'auto_cluster': 'clustering', 'generate_ideas': 'idea_generation', 'generate_content': 'content_generation', 'generate_image_prompts': 'image_prompt_extraction', 'generate_images': 'image_generation', } return mapping.get(function_name, function_name) def _get_estimated_amount(self, function_name, data, payload): """Get estimated amount for credit calculation (before operation)""" if function_name == 'generate_content': # Estimate word count from task or default if isinstance(data, dict): return data.get('estimated_word_count', 1000) return 1000 # Default estimate elif function_name == 'generate_images': # Count images to generate if isinstance(payload, dict): image_ids = payload.get('image_ids', []) return len(image_ids) if image_ids else 1 return 1 elif function_name == 'generate_ideas': # Count clusters if isinstance(data, dict) and 'cluster_data' in data: return len(data['cluster_data']) return 1 # For fixed cost operations (clustering, image_prompt_extraction), return None return None def _get_actual_amount(self, function_name, save_result, parsed, data): """Get actual amount for credit calculation (after operation)""" if function_name == 'generate_content': # Get actual word count from saved content if isinstance(save_result, dict): word_count = save_result.get('word_count') if word_count: return word_count # Fallback: estimate from parsed content if isinstance(parsed, dict) and 'content' in parsed: content = parsed['content'] return len(content.split()) if isinstance(content, str) else 1000 return 1000 elif function_name == 'generate_images': # Count successfully generated images count = save_result.get('count', 0) if count > 0: return count return 1 elif function_name == 'generate_ideas': # Count ideas generated count = save_result.get('count', 0) if count > 0: return count return 1 # For fixed cost operations, return None return None def _get_related_object_type(self, function_name): """Get related object type for credit logging""" mapping = { 'auto_cluster': 'cluster', 'generate_ideas': 'content_idea', 'generate_content': 'content', 'generate_image_prompts': 'image', 'generate_images': 'image', } return mapping.get(function_name, 'unknown')