"""
AI Engine - Central orchestrator for all AI functions
"""
import logging
from typing import Dict, Any, Optional
from igny8_core.ai.base import BaseAIFunction
from igny8_core.ai.tracker import StepTracker, ProgressTracker, CostTracker
from igny8_core.ai.ai_core import AICore

logger = logging.getLogger(__name__)


class AIEngine:
    """
    Central orchestrator for all AI functions.

    Drives a BaseAIFunction through a fixed pipeline (validate -> prepare ->
    build prompt -> AI call -> parse -> save) while reporting progress,
    recording per-step details, tracking cost/tokens and writing AITaskLog /
    CreditUsageLog rows. No retries are performed here; a failed phase ends
    the run with an error result.
    """

    def __init__(self, celery_task=None, account=None):
        """
        Args:
            celery_task: Optional Celery task. It is handed to ProgressTracker and,
                when set, its ``request.id`` is stored on the AITaskLog row.
            account: Optional account the work runs for. It is passed to every
                ``fn`` hook; database logging is skipped when it is missing.
        """
        self.task = celery_task
        self.account = account
        self.tracker = ProgressTracker(celery_task)
        self.step_tracker = StepTracker('ai_engine')
        self.cost_tracker = CostTracker()

    def execute(self, fn: BaseAIFunction, payload: dict) -> dict:
        """
        Unified execution pipeline for all AI functions.

        Phases with improved percentage mapping:
        - INIT (0-10%): Validation & preparation
        - PREP (10-25%): Data loading & prompt building
        - AI_CALL (25-70%): API call to provider (longest phase)
        - PARSE (70-85%): Response parsing
        - SAVE (85-98%): Database operations
        - DONE (98-100%): Finalization

        Args:
            fn: The AI function. Its hooks get_name, validate, prepare,
                build_prompt, get_model, parse_response and save_output are
                called in that order.
            payload: Raw request payload, forwarded to ``fn.validate`` and
                ``fn.prepare``.

        Returns:
            On success, ``{'success': True, **save_result, 'request_steps',
            'response_steps', 'cost', 'tokens'}``. On any failure, the dict
            built by ``_handle_error`` (``'success': False`` plus error details).
            Exceptions are not propagated to the caller.
        """
        function_name = fn.get_name()
        self.step_tracker.function_name = function_name

        try:
            # Phase 1: INIT - Validation & Setup (0-10%)
            validated = fn.validate(payload, self.account)
            # .get() so a validator that omits a key yields a readable error
            # instead of a bare KeyError message.
            if not validated.get('valid'):
                return self._handle_error(validated.get('error', 'Validation failed'), fn)

            self.step_tracker.add_request_step("INIT", "success", "Validation complete")
            self.tracker.update("INIT", 10, "Validation complete", meta=self.step_tracker.get_meta())

            # Phase 2: PREP - Data Loading & Prompt Building (10-25%)
            data = fn.prepare(payload, self.account)
            if isinstance(data, (list, tuple)):
                data_count = len(data)
            elif isinstance(data, dict):
                data_count = len(data.get('keywords', [])) if 'keywords' in data else data.get('count', 1)
            else:
                data_count = 1

            prompt = fn.build_prompt(data, self.account)
            self.step_tracker.add_request_step(
                "PREP", "success", f"Loaded {data_count} items, built prompt ({len(prompt)} chars)"
            )
            self.tracker.update(
                "PREP", 25, f"Data prepared: {data_count} items", meta=self.step_tracker.get_meta()
            )

            # Phase 3: AI_CALL - Provider API Call (25-70%)
            ai_core = AICore(account=self.account)
            model = fn.get_model(self.account)

            # Track AI call start; this step is rewritten in place once the response arrives.
            self.step_tracker.add_response_step("AI_CALL", "success", f"Calling {model or 'default'} model...")
            self.tracker.update(
                "AI_CALL", 30, f"Sending to {model or 'default'}...", meta=self.step_tracker.get_meta()
            )

            try:
                # Use centralized run_ai_request() with console logging
                raw_response = ai_core.run_ai_request(
                    prompt=prompt,
                    model=model,
                    max_tokens=4000,
                    temperature=0.7,
                    function_name=function_name,
                )
            except Exception as e:
                error_msg = f"AI call failed: {str(e)}"
                logger.error(f"Exception during AI call: {error_msg}", exc_info=True)
                return self._handle_error(error_msg, fn)

            if raw_response.get('error'):
                error_msg = raw_response['error']
                logger.error(f"AI call returned error: {error_msg}")
                return self._handle_error(error_msg, fn)

            if not raw_response.get('content'):
                error_msg = "AI call returned no content"
                logger.error(error_msg)
                return self._handle_error(error_msg, fn)

            # Normalise usage figures once. If the response carries these keys with a
            # None value, the ":.6f" formatting below would raise and turn a successful
            # call into a failure.
            cost = raw_response.get('cost') or 0
            total_tokens = raw_response.get('total_tokens') or 0

            # Track cost
            self.cost_tracker.record(
                function_name=function_name,
                cost=cost,
                tokens=total_tokens,
                model=raw_response.get('model'),
            )

            # Update AI_CALL step with results
            self.step_tracker.response_steps[-1] = {
                **self.step_tracker.response_steps[-1],
                'message': f"Received {total_tokens} tokens, Cost: ${cost:.6f}",
                'duration': raw_response.get('duration'),
            }
            self.tracker.update(
                "AI_CALL", 70, f"AI response received ({total_tokens} tokens)",
                meta=self.step_tracker.get_meta(),
            )

            # Phase 4: PARSE - Response Parsing (70-85%)
            response_content = raw_response.get('content', '')
            try:
                parsed = fn.parse_response(response_content, self.step_tracker)
                if isinstance(parsed, (list, tuple)):
                    parsed_count = len(parsed)
                elif isinstance(parsed, dict):
                    parsed_count = parsed.get('count', 1)
                else:
                    parsed_count = 1
            except Exception as parse_error:
                error_msg = f"Failed to parse AI response: {str(parse_error)}"
                logger.error(f"AIEngine: {error_msg}", exc_info=True)
                logger.error(
                    f"AIEngine: Response content was: {response_content[:500] if response_content else 'None'}..."
                )
                return self._handle_error(error_msg, fn)

            self.step_tracker.add_response_step("PARSE", "success", f"Parsed {parsed_count} items from AI response")
            self.tracker.update("PARSE", 85, f"Parsed {parsed_count} items", meta=self.step_tracker.get_meta())

            # Phase 5: SAVE - Database Operations (85-98%)
            # Pass step_tracker to save_output so it can add validation steps
            save_result = fn.save_output(parsed, data, self.account, self.tracker, step_tracker=self.step_tracker)
            # NOTE(review): the SAVE messages and credit logging below assume a clustering
            # function (clusters_created / keywords_updated) although execute() is generic.
            clusters_created = save_result.get('clusters_created', 0)
            keywords_updated = save_result.get('keywords_updated', 0)
            self.step_tracker.add_request_step(
                "SAVE", "success", f"Created {clusters_created} clusters, updated {keywords_updated} keywords"
            )
            self.tracker.update(
                "SAVE", 98, f"Saved: {clusters_created} clusters, {keywords_updated} keywords",
                meta=self.step_tracker.get_meta(),
            )

            # Track credit usage after successful save
            if self.account:
                self._record_credit_usage(function_name, data, raw_response, clusters_created, keywords_updated)

            # Phase 6: DONE - Finalization (98-100%)
            self.step_tracker.add_request_step("DONE", "success", "Task completed successfully")
            self.tracker.update("DONE", 100, "Task complete!", meta=self.step_tracker.get_meta())

            # Log to database
            self._log_to_database(fn, payload, parsed, save_result)

            return {
                'success': True,
                **save_result,
                'request_steps': self.step_tracker.request_steps,
                'response_steps': self.step_tracker.response_steps,
                'cost': self.cost_tracker.get_total(),
                'tokens': self.cost_tracker.get_total_tokens(),
            }

        except Exception as e:
            # _handle_error logs with the traceback (exc_info=True); logging here too
            # would emit the same stack trace twice.
            return self._handle_error(str(e), fn, exc_info=True)

    def _record_credit_usage(
        self,
        function_name: str,
        data: Any,
        raw_response: Dict[str, Any],
        clusters_created: int,
        keywords_updated: int,
    ) -> None:
        """Write a CreditUsageLog row for a completed run (best effort).

        Credits are only logged; ``account.credits`` is not deducted. Any
        failure is logged as a warning and never fails the task.
        """
        try:
            from igny8_core.modules.billing.models import CreditUsageLog

            # Keyword count mirrors how PREP counts items (lists and tuples alike).
            if isinstance(data, dict):
                keyword_count = len(data.get('keywords', []))
            elif isinstance(data, (list, tuple)):
                keyword_count = len(data)
            else:
                keyword_count = 1

            # Calculate credits used (based on tokens or fixed cost)
            credits_used = self._calculate_credits_for_clustering(
                keyword_count=keyword_count,
                tokens=raw_response.get('total_tokens', 0),
                cost=raw_response.get('cost', 0),
            )

            # Log credit usage (don't deduct from account.credits, just log)
            CreditUsageLog.objects.create(
                account=self.account,
                operation_type='clustering',
                credits_used=credits_used,
                cost_usd=raw_response.get('cost'),
                model_used=raw_response.get('model', ''),
                tokens_input=raw_response.get('tokens_input', 0),
                tokens_output=raw_response.get('tokens_output', 0),
                related_object_type='cluster',
                metadata={
                    'clusters_created': clusters_created,
                    'keywords_updated': keywords_updated,
                    'function_name': function_name,
                },
            )
        except Exception as e:
            logger.warning(f"Failed to log credit usage: {e}", exc_info=True)

    def _handle_error(self, error: str, fn: Optional[BaseAIFunction] = None, exc_info: bool = False) -> dict:
        """Centralized error handling.

        Records an "Error" request step, pushes the error to the progress
        tracker, logs it, writes an error AITaskLog row, and returns the
        failure result dict returned by ``execute``.

        Args:
            error: Error message (an Exception instance is also tolerated).
            fn: The AI function that failed, if known.
            exc_info: When True, the log record includes the active traceback;
                only meaningful when called from inside an ``except`` block.
        """
        function_name = fn.get_name() if fn else 'unknown'
        # Messages are normally strings, giving 'Error'; an Exception instance
        # is reported under its class name.
        error_type = type(error).__name__ if isinstance(error, Exception) else 'Error'

        self.step_tracker.add_request_step("Error", "error", error, error=error)
        error_meta = {
            'error': error,
            'error_type': error_type,
            **self.step_tracker.get_meta(),
        }
        self.tracker.error(error, meta=error_meta)

        logger.error(f"Error in {function_name}: {error}", exc_info=exc_info)

        self._log_to_database(fn, None, None, None, error=error)

        return {
            'success': False,
            'error': error,
            'error_type': error_type,
            'request_steps': self.step_tracker.request_steps,
            'response_steps': self.step_tracker.response_steps,
        }

    def _log_to_database(
        self,
        fn: Optional[BaseAIFunction] = None,
        payload: Optional[dict] = None,
        parsed: Any = None,
        save_result: Optional[dict] = None,
        error: Optional[str] = None,
    ) -> None:
        """Log to unified ai_task_logs table.

        Best effort: skipped without an account, and any exception is logged
        as a warning rather than raised. ``parsed`` is accepted but not stored.
        """
        try:
            from igny8_core.ai.models import AITaskLog

            # Only log if account exists (AITaskLog requires account)
            if not self.account:
                logger.warning("Cannot log AI task - no account available")
                return

            AITaskLog.objects.create(
                task_id=self.task.request.id if self.task else None,
                function_name=fn.get_name() if fn else None,
                account=self.account,
                phase=self.tracker.current_phase,
                message=self.tracker.current_message,
                status='error' if error else 'success',
                duration=self.tracker.get_duration(),
                cost=self.cost_tracker.get_total(),
                tokens=self.cost_tracker.get_total_tokens(),
                request_steps=self.step_tracker.request_steps,
                response_steps=self.step_tracker.response_steps,
                error=error,
                payload=payload,
                result=save_result,
            )
        except Exception as e:
            # Don't fail the task if logging fails
            logger.warning(f"Failed to log to database: {e}", exc_info=True)

    def _calculate_credits_for_clustering(self, keyword_count, tokens, cost):
        """Calculate credits used for clustering operation.

        Uses the plan's ``ai_cost_per_request['cluster']`` when it is set and
        non-zero; otherwise 1 credit per 30 keywords (rounded down, minimum 1).
        ``tokens`` and ``cost`` are accepted but currently unused.
        """
        # getattr(..., None) matches the original hasattr()+truthiness check.
        plan = getattr(self.account, 'plan', None) if self.account else None
        if plan:
            # NOTE(review): presumably a dict keyed by operation name — confirm against the Plan model.
            ai_costs = getattr(plan, 'ai_cost_per_request', None)
            if ai_costs:
                cluster_cost = ai_costs.get('cluster', 0)
                if cluster_cost:
                    return int(cluster_cost)

        # Fallback: 1 credit per 30 keywords (minimum 1)
        return max(1, int(keyword_count / 30))