"""
AI Engine - Central orchestrator for all AI functions.
"""

import logging
from typing import Any, Dict, Optional

from igny8_core.ai.base import BaseAIFunction
from igny8_core.ai.processor import AIProcessor
from igny8_core.ai.tracker import CostTracker, ProgressTracker, StepTracker

# Module-level logger named after this module (standard logging idiom).
logger = logging.getLogger(__name__)


class AIEngine:
    """
    Central orchestrator for all AI functions.

    Drives an AI function through validation, data preparation, prompt
    building, the provider call, response parsing and saving. Along the way it
    reports progress (ProgressTracker), records per-phase steps (StepTracker),
    accumulates cost/tokens (CostTracker) and writes one AITaskLog row per run.

    NOTE(review): an earlier description of this class mentioned retries; no
    retry logic is implemented here.
    """

    def __init__(self, celery_task=None, account=None):
        """
        Create an engine for a single run.

        Args:
            celery_task: Optional Celery task. It is handed to ProgressTracker,
                and its ``request.id`` is stored on the AITaskLog row.
            account: Optional account the work runs for. It is passed to every
                AI-function hook and is required for database and credit logging.
        """
        self.task = celery_task
        self.account = account
        self.tracker = ProgressTracker(celery_task)
        self.step_tracker = StepTracker('ai_engine')
        self.cost_tracker = CostTracker()

    def execute(self, fn: "BaseAIFunction", payload: dict) -> dict:
        """
        Unified execution pipeline for all AI functions.

        Phases with improved percentage mapping:

        - INIT (0-10%): Validation & preparation
        - PREP (10-25%): Data loading & prompt building
        - AI_CALL (25-70%): API call to provider (longest phase)
        - PARSE (70-85%): Response parsing
        - SAVE (85-98%): Database operations
        - DONE (98-100%): Finalization

        Args:
            fn: The AI function. Its hooks are called in this order:
                validate, prepare, build_prompt, get_model, parse_response,
                save_output.
            payload: Request payload. It is passed to ``fn.validate`` and
                ``fn.prepare`` and stored on the AITaskLog row.

        Returns:
            On success:
            ``{'success': True, **save_result, 'request_steps', 'response_steps', 'cost', 'tokens'}``.
            On failure: the dict built by ``_handle_error`` (``success`` is
            False). Exceptions raised inside the pipeline are converted into
            that failure dict rather than re-raised.
        """
        # Resolved outside the try block, as before, so the name is available
        # for every step and log below.
        function_name = fn.get_name()
        self.step_tracker.function_name = function_name

        try:
            # Phase 1: INIT - Validation & Setup (0-10%)
            validated = fn.validate(payload, self.account)
            if not validated['valid']:
                return self._handle_error(validated['error'], fn, payload=payload)

            self.step_tracker.add_request_step("INIT", "success", "Validation complete")
            self.tracker.update("INIT", 10, "Validation complete", meta=self.step_tracker.get_meta())

            # Phase 2: PREP - Data Loading & Prompt Building (10-25%)
            data = fn.prepare(payload, self.account)
            if isinstance(data, (list, tuple)):
                data_count = len(data)
            elif isinstance(data, dict):
                data_count = len(data.get('keywords', [])) if 'keywords' in data else data.get('count', 1)
            else:
                data_count = 1

            prompt = fn.build_prompt(data, self.account)
            self.step_tracker.add_request_step("PREP", "success", f"Loaded {data_count} items, built prompt ({len(prompt)} chars)")
            self.tracker.update("PREP", 25, f"Data prepared: {data_count} items", meta=self.step_tracker.get_meta())

            # Phase 3: AI_CALL - Provider API Call (25-70%)
            processor = AIProcessor(account=self.account)
            model = fn.get_model(self.account)

            # Record the call up front. This step is rewritten in place with
            # token/cost details once the response arrives.
            # NOTE(review): on the failure branches below this step keeps its
            # "success" status; only an extra "Error" request step is added.
            self.step_tracker.add_response_step("AI_CALL", "success", f"Calling {model or 'default'} model...")
            self.tracker.update("AI_CALL", 30, f"Sending to {model or 'default'}...", meta=self.step_tracker.get_meta())

            try:
                raw_response = processor.call(
                    prompt,
                    model=model,
                    # Step tracking is handled by the engine. The callback only
                    # forwards provider progress merged with the current steps.
                    progress_callback=lambda state, meta: self.tracker.update_ai_progress(state, {
                        **meta,
                        **self.step_tracker.get_meta()
                    })
                )
            except Exception as e:
                error_msg = f"AI call failed: {str(e)}"
                logger.error(f"Exception during AI call: {error_msg}", exc_info=True)
                return self._handle_error(error_msg, fn, payload=payload)

            if raw_response.get('error'):
                error_msg = raw_response.get('error', 'Unknown AI error')
                logger.error(f"AI call returned error: {error_msg}")
                return self._handle_error(error_msg, fn, payload=payload)

            if not raw_response.get('content'):
                error_msg = "AI call returned no content"
                logger.error(error_msg)
                return self._handle_error(error_msg, fn, payload=payload)

            # The keys may be present with a None value. Normalize them so the
            # ``:.6f`` formatting below cannot turn a successful call into a
            # failure.
            total_tokens = raw_response.get('total_tokens') or 0
            cost = raw_response.get('cost') or 0

            # Track cost
            self.cost_tracker.record(
                function_name=function_name,
                cost=cost,
                tokens=total_tokens,
                model=raw_response.get('model')
            )

            # Update the AI_CALL step added above with the results.
            self.step_tracker.response_steps[-1] = {
                **self.step_tracker.response_steps[-1],
                'message': f"Received {total_tokens} tokens, Cost: ${cost:.6f}",
                'duration': raw_response.get('duration')
            }
            self.tracker.update("AI_CALL", 70, f"AI response received ({total_tokens} tokens)", meta=self.step_tracker.get_meta())

            # Phase 4: PARSE - Response Parsing (70-85%)
            # Bound before the try block so the except branch can always log it.
            response_content = raw_response.get('content', '')
            try:
                parsed = fn.parse_response(response_content, self.step_tracker)

                if isinstance(parsed, (list, tuple)):
                    parsed_count = len(parsed)
                elif isinstance(parsed, dict):
                    parsed_count = parsed.get('count', 1)
                else:
                    parsed_count = 1

                self.step_tracker.add_response_step("PARSE", "success", f"Parsed {parsed_count} items from AI response")
                self.tracker.update("PARSE", 85, f"Parsed {parsed_count} items", meta=self.step_tracker.get_meta())
            except Exception as parse_error:
                error_msg = f"Failed to parse AI response: {str(parse_error)}"
                logger.error(f"AIEngine: {error_msg}", exc_info=True)
                logger.error(f"AIEngine: Response content was: {response_content[:500] if response_content else 'None'}...")
                return self._handle_error(error_msg, fn, payload=payload)

            # Phase 5: SAVE - Database Operations (85-98%)
            # save_output receives the step tracker so it can add validation steps.
            # `or {}` tolerates a save_output that returns None.
            save_result = fn.save_output(parsed, data, self.account, self.tracker, step_tracker=self.step_tracker) or {}
            # NOTE(review): this summary (and the credit log below) is
            # clustering-specific even though execute() runs every AI function.
            # Other functions will report 0 clusters / 0 keywords.
            clusters_created = save_result.get('clusters_created', 0)
            keywords_updated = save_result.get('keywords_updated', 0)
            self.step_tracker.add_request_step("SAVE", "success", f"Created {clusters_created} clusters, updated {keywords_updated} keywords")
            self.tracker.update("SAVE", 98, f"Saved: {clusters_created} clusters, {keywords_updated} keywords", meta=self.step_tracker.get_meta())

            # Track credit usage after a successful save (best-effort, never raises).
            self._log_credit_usage(
                function_name, data, raw_response, cost, total_tokens,
                clusters_created, keywords_updated
            )

            # Phase 6: DONE - Finalization (98-100%)
            self.step_tracker.add_request_step("DONE", "success", "Task completed successfully")
            self.tracker.update("DONE", 100, "Task complete!", meta=self.step_tracker.get_meta())

            # Log to database
            self._log_to_database(fn, payload, parsed, save_result)

            return {
                'success': True,
                **save_result,
                'request_steps': self.step_tracker.request_steps,
                'response_steps': self.step_tracker.response_steps,
                'cost': self.cost_tracker.get_total(),
                'tokens': self.cost_tracker.get_total_tokens()
            }

        except Exception as e:
            # _handle_error logs the message together with the active traceback.
            # Logging here as well would emit the same traceback twice.
            return self._handle_error(str(e), fn, exc_info=True, payload=payload)

    def _log_credit_usage(self, function_name, data, raw_response, cost, total_tokens,
                          clusters_created, keywords_updated):
        """
        Best-effort: write one CreditUsageLog row for a completed run.

        Does nothing when no account is set. Any exception, including an
        import failure of the billing models, is logged as a warning and
        swallowed. Credits are only logged, not deducted from the account.

        Args:
            function_name: Name of the AI function, stored in the row metadata.
            data: Value returned by ``fn.prepare``. It is used to count
                keywords: a dict uses its 'keywords' list, a list uses its
                length, and anything else counts as 1.
            raw_response: Provider response dict. Its model, cost and token
                fields are stored.
            cost: Normalized provider cost, passed to the credit calculation.
            total_tokens: Normalized total tokens, passed to the credit calculation.
            clusters_created: Count stored in the row metadata.
            keywords_updated: Count stored in the row metadata.
        """
        if not self.account:
            return
        try:
            from igny8_core.modules.billing.models import CreditUsageLog

            if isinstance(data, dict):
                keyword_count = len(data.get('keywords', []))
            elif isinstance(data, list):
                keyword_count = len(data)
            else:
                keyword_count = 1

            # Calculate credits used (plan-configured cost, else keyword based).
            credits_used = self._calculate_credits_for_clustering(
                keyword_count=keyword_count,
                tokens=total_tokens,
                cost=cost
            )

            # Log credit usage (don't deduct from account.credits, just log)
            CreditUsageLog.objects.create(
                account=self.account,
                operation_type='clustering',
                credits_used=credits_used,
                cost_usd=raw_response.get('cost'),
                model_used=raw_response.get('model', ''),
                tokens_input=raw_response.get('tokens_input', 0),
                tokens_output=raw_response.get('tokens_output', 0),
                related_object_type='cluster',
                metadata={
                    'clusters_created': clusters_created,
                    'keywords_updated': keywords_updated,
                    'function_name': function_name
                }
            )
        except Exception as e:
            logger.warning(f"Failed to log credit usage: {e}", exc_info=True)

    def _handle_error(self, error: str, fn: Optional["BaseAIFunction"] = None, exc_info=False,
                      payload: Optional[dict] = None) -> Dict[str, Any]:
        """
        Centralized error handling.

        Adds an "Error" request step, reports the error to the progress
        tracker, logs it, writes an error AITaskLog row, and builds the failure
        result.

        Args:
            error: Error message.
            fn: The AI function that failed, if known (only its name is used).
            exc_info: When True, the log record includes the active exception's
                traceback. Pass True only from inside an ``except`` block.
            payload: Original request payload, stored on the AITaskLog row.
                Defaults to None, which matches the previous behavior.

        Returns:
            ``{'success': False, 'error', 'error_type', 'request_steps', 'response_steps'}``.
        """
        function_name = fn.get_name() if fn else 'unknown'
        self.step_tracker.add_request_step("Error", "error", error, error=error)

        # Callers currently pass strings, so this is normally 'Error'.
        error_type = type(error).__name__ if isinstance(error, Exception) else 'Error'

        error_meta = {
            'error': error,
            'error_type': error_type,
            **self.step_tracker.get_meta()
        }
        self.tracker.error(error, meta=error_meta)

        logger.error(f"Error in {function_name}: {error}", exc_info=exc_info)

        self._log_to_database(fn, payload, None, None, error=error)

        return {
            'success': False,
            'error': error,
            'error_type': error_type,
            'request_steps': self.step_tracker.request_steps,
            'response_steps': self.step_tracker.response_steps
        }

    def _log_to_database(
        self,
        fn: Optional["BaseAIFunction"] = None,
        payload: Optional[dict] = None,
        parsed: Any = None,
        save_result: Optional[dict] = None,
        error: Optional[str] = None
    ):
        """
        Log to the unified ai_task_logs table (best-effort).

        Skipped with a warning when no account is set, because AITaskLog
        requires one. Any exception is logged as a warning and swallowed, so
        logging never fails the task.

        Args:
            fn: AI function that ran, if known (its name is stored).
            payload: Request payload to store.
            parsed: Parsed AI response. It is accepted but not currently persisted.
            save_result: Result of ``fn.save_output``, stored as ``result``.
            error: Error message. When set, the row's status is ``'error'``.
        """
        try:
            from igny8_core.ai.models import AITaskLog

            # Only log if account exists (AITaskLog requires account)
            if not self.account:
                logger.warning("Cannot log AI task - no account available")
                return

            AITaskLog.objects.create(
                task_id=self.task.request.id if self.task else None,
                function_name=fn.get_name() if fn else None,
                account=self.account,
                phase=self.tracker.current_phase,
                message=self.tracker.current_message,
                status='error' if error else 'success',
                duration=self.tracker.get_duration(),
                cost=self.cost_tracker.get_total(),
                tokens=self.cost_tracker.get_total_tokens(),
                request_steps=self.step_tracker.request_steps,
                response_steps=self.step_tracker.response_steps,
                error=error,
                payload=payload,
                result=save_result
            )
        except Exception as e:
            # Don't fail the task if logging fails
            logger.warning(f"Failed to log to database: {e}", exc_info=True)

    def _calculate_credits_for_clustering(self, keyword_count, tokens, cost):
        """
        Calculate credits used for a clustering operation.

        Uses the account plan's ``ai_cost_per_request['cluster']`` when it is
        set and non-zero. Otherwise it falls back to 1 credit per 30 keywords,
        rounded down, with a minimum of 1.

        Args:
            keyword_count: Number of keywords processed (non-negative int).
            tokens: Total provider tokens (currently unused).
            cost: Provider cost (currently unused).

        Returns:
            int: credits to log.
        """
        plan = getattr(self.account, 'plan', None) if self.account else None
        # Check if the plan has an ai_cost_per_request config.
        cost_table = getattr(plan, 'ai_cost_per_request', None) if plan else None
        if cost_table:
            cluster_cost = cost_table.get('cluster', 0)
            if cluster_cost:
                return int(cluster_cost)

        # Fallback: 1 credit per 30 keywords (minimum 1). Floor division is
        # exact for ints, unlike int(n / 30).
        return max(1, keyword_count // 30)