"""
Progress and step-tracking utilities for the AI framework.
"""

import logging
import sys
import time
import traceback
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional

from igny8_core.ai.types import StepLog, ProgressState
from igny8_core.ai.constants import DEBUG_MODE

logger = logging.getLogger(__name__)
|
||
|
||
|
||
class StepTracker:
    """Collect request- and response-side step records for debugging.

    Both kinds of step draw ``stepNumber`` from one shared counter, so step
    numbers are unique and increasing across ``request_steps`` and
    ``response_steps`` combined.
    """

    def __init__(self, function_name: str):
        self.function_name = function_name
        self.request_steps: List[Dict] = []
        self.response_steps: List[Dict] = []
        self.step_counter = 0

    def _next_step(
        self,
        step_name: str,
        status: str,
        message: str,
        error: Optional[str],
        duration: Optional[int],
    ) -> Dict:
        """Build the next numbered step dict (shared by both ``add_*`` methods)."""
        self.step_counter += 1
        step = {
            'stepNumber': self.step_counter,
            'stepName': step_name,
            'functionName': self.function_name,
            'status': status,
            'message': message,
            'duration': duration,
        }
        # Attach 'error' only when one was supplied so successful steps stay compact.
        if error:
            step['error'] = error
        return step

    def add_request_step(
        self,
        step_name: str,
        status: str = 'success',
        message: str = '',
        error: Optional[str] = None,
        duration: Optional[int] = None,
    ) -> Dict:
        """Append a request-side step and return it.

        Args:
            step_name: Label for the step.
            status: Step status string (defaults to ``'success'``).
            message: Free-form description of the step.
            error: Error text; the ``'error'`` key is added only when truthy.
            duration: Caller-measured duration, stored verbatim. No timing is
                performed here (presumably milliseconds -- TODO confirm).

        Returns:
            The step dict that was appended to ``request_steps``.
        """
        step = self._next_step(step_name, status, message, error, duration)
        self.request_steps.append(step)
        return step

    def add_response_step(
        self,
        step_name: str,
        status: str = 'success',
        message: str = '',
        error: Optional[str] = None,
        duration: Optional[int] = None,
    ) -> Dict:
        """Append a response-side step and return it.

        Args:
            step_name: Label for the step.
            status: Step status string (defaults to ``'success'``).
            message: Free-form description of the step.
            error: Error text; the ``'error'`` key is added only when truthy.
            duration: Caller-measured duration, stored verbatim. No timing is
                performed here (presumably milliseconds -- TODO confirm).

        Returns:
            The step dict that was appended to ``response_steps``.
        """
        step = self._next_step(step_name, status, message, error, duration)
        self.response_steps.append(step)
        return step

    def get_meta(self) -> Dict:
        """Return both step lists, keyed for use as progress-callback metadata.

        The lists are returned by reference (not copied), so steps added later
        also show up in a previously returned dict.
        """
        return {
            'request_steps': self.request_steps,
            'response_steps': self.response_steps,
        }


class ProgressTracker:
    """Track progress of an AI task and mirror it onto a Celery task's state.

    The last reported phase, percentage and message are kept on the instance
    (``current_phase``, ``current_percentage``, ``current_message``), along with
    ``current``/``total`` counters, so partial updates can fall back to them.
    """

    def __init__(self, celery_task=None):
        """
        Args:
            celery_task: Optional object exposing ``update_state(state=..., meta=...)``
                (presumably a bound Celery task -- TODO confirm at call sites).
                When falsy, progress is only logged.
        """
        self.task = celery_task
        self.current_phase = 'INIT'
        self.current_message = 'Initializing...'
        self.current_percentage = 0
        self.start_time = time.time()
        self.current = 0
        self.total = 0

    def _push_state(self, state: str, meta: Dict) -> None:
        """Forward ``state``/``meta`` to the Celery task, if any, best-effort."""
        if not self.task:
            return
        try:
            self.task.update_state(state=state, meta=meta)
        except Exception as e:
            # Progress reporting is advisory: a failing result backend must not
            # abort the task itself, so log and carry on.
            logger.warning("Failed to update Celery task state: %s", e)

    def _remember(self, phase: str, percentage: int, message: str) -> None:
        """Store the last reported phase/percentage/message on the instance."""
        self.current_phase = phase
        self.current_percentage = percentage
        self.current_message = message

    def update(
        self,
        phase: str,
        percentage: int,
        message: str,
        current: Optional[int] = None,
        total: Optional[int] = None,
        current_item: Optional[str] = None,
        meta: Optional[Dict] = None,
    ):
        """Report an in-progress update (Celery state ``PROGRESS``).

        Args:
            phase: Phase label, e.g. ``'INIT'``.
            percentage: Progress value reported as ``percentage``.
            message: Human-readable status message.
            current: New item counter; ``None`` keeps the previous value.
            total: New total counter; ``None`` keeps the previous value.
            current_item: Optional label of the item being processed; only
                included in the reported meta when truthy.
            meta: Extra keys merged last, so they can override the defaults.
        """
        self._remember(phase, percentage, message)
        if current is not None:
            self.current = current
        if total is not None:
            self.total = total

        progress_meta = {
            'phase': phase,
            'percentage': percentage,
            'message': message,
            'current': self.current,
            'total': self.total,
        }
        if current_item:
            progress_meta['current_item'] = current_item
        if meta:
            progress_meta.update(meta)

        self._push_state('PROGRESS', progress_meta)
        logger.info("[%s] %s%%: %s", phase, percentage, message)

    def set_phase(self, phase: str, percentage: int, message: str, meta: Optional[Dict] = None):
        """Report a phase change; thin wrapper around :meth:`update`."""
        self.update(phase, percentage, message, meta=meta)

    def complete(self, message: str = "Task complete!", meta: Optional[Dict] = None):
        """Report the task as finished (phase ``DONE``, 100%, Celery state ``SUCCESS``).

        Args:
            message: Human-readable completion message.
            meta: Extra keys merged over the defaults (may override them).
        """
        final_meta = {
            'phase': 'DONE',
            'percentage': 100,
            'message': message,
            'status': 'success',
        }
        if meta:
            final_meta.update(meta)

        # Keep the in-memory state consistent with what was just reported, so a
        # later update_ai_progress fallback doesn't resurrect stale values.
        self._remember(final_meta['phase'], final_meta['percentage'], final_meta['message'])
        # NOTE(review): Celery normally records its own terminal state when the
        # task returns, which may overwrite this SUCCESS meta -- confirm pollers.
        self._push_state('SUCCESS', final_meta)
        logger.info("[%s] %s%%: %s", final_meta['phase'], final_meta['percentage'], final_meta['message'])

    def error(self, error_message: str, meta: Optional[Dict] = None):
        """Report the task as failed (phase ``ERROR``, Celery state ``FAILURE``).

        Args:
            error_message: Description of the failure; also stored under ``'error'``.
            meta: Extra keys merged over the defaults (may override them).
        """
        error_meta = {
            'phase': 'ERROR',
            'percentage': 0,
            'message': f'Error: {error_message}',
            'status': 'error',
            'error': error_message,
        }
        if meta:
            error_meta.update(meta)

        self._remember(error_meta['phase'], error_meta['percentage'], error_meta['message'])
        # NOTE(review): a FAILURE meta that is not serialized exception info may
        # fail to decode client-side (AsyncResult.result) -- confirm consumers.
        self._push_state('FAILURE', error_meta)
        logger.error("[%s] %s%%: %s", error_meta['phase'], error_meta['percentage'], error_meta['message'])

    def get_duration(self) -> int:
        """Return elapsed time since construction, in whole milliseconds."""
        return int((time.time() - self.start_time) * 1000)

    def update_ai_progress(self, state: str, meta: Dict):
        """Adapter for AI-processor progress callbacks.

        Args:
            state: Not used by this method (presumably kept for callback
                signature compatibility -- confirm with the AI processor).
            meta: Progress dict; missing ``phase``/``percentage``/``message``
                fall back to the last reported values. Non-dict values are
                ignored.
        """
        if not isinstance(meta, dict):
            return
        self.update(
            meta.get('phase', self.current_phase),
            meta.get('percentage', self.current_percentage),
            meta.get('message', self.current_message),
            meta=meta,
        )


class CostTracker:
    """Accumulate API spend and token usage across calls.

    Each :meth:`record` call bumps the running totals and appends one entry
    (``function``, ``cost``, ``tokens``, ``model``) to ``operations``.
    """

    def __init__(self):
        self.total_cost = 0.0
        self.total_tokens = 0
        self.operations = []

    def record(self, function_name: str, cost: float, tokens: int, model: str = None):
        """Add one API call to the running totals and the operation log.

        Args:
            function_name: Name of the calling function, stored as ``function``.
            cost: Cost of the call, added to ``total_cost``.
            tokens: Tokens used, added to ``total_tokens``.
            model: Optional model identifier, stored verbatim.
        """
        entry = dict(function=function_name, cost=cost, tokens=tokens, model=model)
        self.total_cost = self.total_cost + cost
        self.total_tokens = self.total_tokens + tokens
        self.operations.append(entry)

    def get_total(self) -> float:
        """Return the accumulated cost."""
        return self.total_cost

    def get_total_tokens(self) -> int:
        """Return the accumulated token count."""
        return self.total_tokens

    def get_operations(self) -> List[Dict]:
        """Return the operation log (the live list, not a copy)."""
        return self.operations


class ConsoleStepTracker:
    """
    Lightweight console-based step tracker for AI functions.

    Each step is sent to the module logger and echoed to stdout with a
    timestamp, the function name and a phase label. Step output (and step
    bookkeeping in ``steps``/``current_phase``) only happens when DEBUG_MODE is
    True; the constructor, however, always announces whether DEBUG_MODE is on.
    """

    def __init__(self, function_name: str):
        self.function_name = function_name
        self.start_time = time.time()
        self.steps: List[Dict[str, str]] = []
        self.current_phase: Optional[str] = None

        # Announce once per tracker whether step logging is active, so a silent
        # worker can be told apart from one running with DEBUG_MODE disabled.
        if DEBUG_MODE:
            self._emit(
                logging.INFO,
                f"[DEBUG] ConsoleStepTracker initialized for '{function_name}' - DEBUG_MODE is ENABLED",
            )
        else:
            self._emit(
                logging.WARNING,
                f"[WARNING] ConsoleStepTracker initialized for '{function_name}' - DEBUG_MODE is DISABLED",
            )

    @staticmethod
    def _emit(level: int, msg: str) -> None:
        """Send ``msg`` to the module logger at ``level`` and echo it to stdout."""
        logger.log(level, msg)
        # Also print to stdout for immediate visibility (e.g. in Celery worker logs).
        print(msg, flush=True, file=sys.stdout)

    def _log(self, phase: str, message: str, status: str = 'info'):
        """Emit one step line and record it; a no-op when DEBUG_MODE is off.

        Args:
            phase: Phase label; upper-cased for display, stored as given.
            message: Step message.
            status: ``'error'`` logs at ERROR level with an ``[ERROR]`` tag,
                ``'success'`` adds a check mark, anything else is plain INFO.
        """
        if not DEBUG_MODE:
            # Note: this suppresses error steps too, not just informational ones.
            return

        timestamp = datetime.now().strftime('%H:%M:%S')
        prefix = f"[{timestamp}] [{self.function_name}] [{phase.upper()}]"
        if status == 'error':
            level, log_msg = logging.ERROR, f"{prefix} [ERROR] {message}"
        elif status == 'success':
            level, log_msg = logging.INFO, f"{prefix} ✅ {message}"
        else:
            level, log_msg = logging.INFO, f"{prefix} {message}"
        self._emit(level, log_msg)

        self.steps.append({
            'timestamp': timestamp,
            'phase': phase,
            'message': message,
            'status': status,
        })
        self.current_phase = phase

    def init(self, message: str = "Task started"):
        """Log the initialization phase."""
        self._log('INIT', message)

    def prep(self, message: str):
        """Log the preparation phase."""
        self._log('PREP', message)

    def ai_call(self, message: str):
        """Log the AI call phase."""
        self._log('AI_CALL', message)

    def parse(self, message: str):
        """Log the parsing phase."""
        self._log('PARSE', message)

    def save(self, message: str):
        """Log the save phase."""
        self._log('SAVE', message)

    def done(self, message: str = "Execution completed"):
        """Log completion with elapsed seconds since the tracker was created."""
        duration = time.time() - self.start_time
        self._log('DONE', f"{message} (Duration: {duration:.2f}s)", status='success')
        if DEBUG_MODE:
            self._emit(logging.INFO, f"[{self.function_name}] === AI Task Complete ===")

    def error(self, error_type: str, message: str, exception: Optional[Exception] = None):
        """Log an error as ``"<error_type> – <message>"`` under the current phase.

        Args:
            error_type: Short error category, e.g. ``'Timeout'``.
            message: Error description.
            exception: Optional exception; its class name is appended to the
                message and, in DEBUG_MODE, its stack trace is logged and printed.
        """
        error_msg = f"{error_type} – {message}"
        if exception is not None:
            error_msg += f" ({type(exception).__name__})"
        self._log(self.current_phase or 'ERROR', error_msg, status='error')

        if DEBUG_MODE and exception is not None:
            error_trace_msg = f"[{self.function_name}] [ERROR] Stack trace:"
            logger.error(error_trace_msg, exc_info=exception)
            print(error_trace_msg, flush=True, file=sys.stdout)
            # Print the traceback of *this* exception. traceback.print_exc() would
            # print whatever exception is currently being handled instead, which
            # is "NoneType: None" when error() is called outside an except block.
            traceback.print_exception(
                type(exception), exception, exception.__traceback__, file=sys.stdout
            )

    def retry(self, attempt: int, max_attempts: int, reason: str = ""):
        """Log a retry attempt under the AI_CALL phase."""
        msg = f"Retry attempt {attempt}/{max_attempts}"
        if reason:
            msg += f" – {reason}"
        self._log('AI_CALL', msg, status='info')

    def timeout(self, timeout_seconds: int):
        """Log a request timeout as an error."""
        self.error('Timeout', f"Request timeout after {timeout_seconds}s")

    def rate_limit(self, retry_after: str):
        """Log a rate-limit hit as an error."""
        self.error('RateLimit', f"OpenAI rate limit hit, retry in {retry_after}s")

    def malformed_json(self, details: str = ""):
        """Log a model-response JSON parsing failure as an error."""
        msg = "Failed to parse model response: Unexpected JSON"
        if details:
            msg += f" – {details}"
        self.error('MalformedJSON', msg)