340 lines
11 KiB
Python
340 lines
11 KiB
Python
"""
|
|
Model Registry Service
|
|
Central registry for AI model configurations with caching.
|
|
Replaces hardcoded MODEL_RATES and IMAGE_MODEL_RATES from constants.py
|
|
|
|
This service provides:
|
|
- Database-driven model configuration (from AIModelConfig)
|
|
- Fallback to constants.py for backward compatibility
|
|
- Caching for performance
|
|
- Cost calculation methods
|
|
|
|
Usage:
|
|
from igny8_core.ai.model_registry import ModelRegistry
|
|
|
|
# Get model config
|
|
model = ModelRegistry.get_model('gpt-4o-mini')
|
|
|
|
# Get rate
|
|
input_rate = ModelRegistry.get_rate('gpt-4o-mini', 'input')
|
|
|
|
# Calculate cost
|
|
cost = ModelRegistry.calculate_cost('gpt-4o-mini', input_tokens=1000, output_tokens=500)
|
|
"""
|
|
import logging
|
|
from decimal import Decimal
|
|
from typing import Optional, Dict, Any
|
|
from django.core.cache import cache
|
|
|
|
# Prefix shared by every cache key this module writes.
CACHE_KEY_PREFIX = 'ai_model_'

# How long a cached model configuration stays valid: 5 minutes, in seconds.
MODEL_CACHE_TTL = 5 * 60

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ModelRegistry:
    """
    Central registry for AI model configurations with caching.

    Configurations are looked up in the ``AIModelConfig`` database model, with
    a fallback to the rate tables in ``constants.py``. A configuration is
    therefore either an ``AIModelConfig`` instance or a plain dict that uses
    the same field names as keys (dicts carry ``'_from_constants': True``).
    """

    # Maps the public ``rate_type`` names to the configuration field that
    # stores the corresponding rate.
    _RATE_FIELDS = {
        'input': 'input_cost_per_1m',
        'output': 'output_cost_per_1m',
        'image': 'cost_per_image',
    }

    @classmethod
    def _get_cache_key(cls, model_id: str) -> str:
        """Generate the cache key for a model."""
        return f"{CACHE_KEY_PREFIX}{model_id}"

    @classmethod
    def _get_from_db(cls, model_id: str) -> Optional[Any]:
        """
        Get the active model config from the database.

        Returns:
            The first active ``AIModelConfig`` whose ``model_name`` equals
            ``model_id``, or None when there is none or the lookup fails.
        """
        # Best-effort lookup: any failure (import or query) is logged at
        # debug level and reported as "not found" so callers can fall back.
        try:
            from igny8_core.business.billing.models import AIModelConfig
            return AIModelConfig.objects.filter(
                model_name=model_id,
                is_active=True,
            ).first()
        except Exception as e:
            logger.debug("Could not fetch model %s from DB: %s", model_id, e)
            return None

    @classmethod
    def _get_db_model_names(cls) -> list:
        """
        Return every model name stored in the database (active or not).

        Returns an empty list when the database cannot be queried.
        """
        try:
            from igny8_core.business.billing.models import AIModelConfig
            return list(AIModelConfig.objects.values_list('model_name', flat=True))
        except Exception as e:
            logger.debug("Could not list model names from DB: %s", e)
            return []

    @classmethod
    def _get_from_constants(cls, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get model config from constants.py as fallback.

        Returns a dict mimicking AIModelConfig attributes, or None when the
        model is in neither MODEL_RATES nor IMAGE_MODEL_RATES. Text models are
        checked first and are always labelled with provider 'openai'; image
        models are labelled 'openai' when the id contains 'dall-e', otherwise
        'runware'.
        """
        from igny8_core.ai.constants import MODEL_RATES, IMAGE_MODEL_RATES

        # Check text models first
        if model_id in MODEL_RATES:
            rates = MODEL_RATES[model_id]
            return {
                'model_name': model_id,
                'display_name': model_id,
                'model_type': 'text',
                'provider': 'openai',
                # Go through str() so the Decimal holds the written value
                # rather than a binary float approximation.
                'input_cost_per_1m': Decimal(str(rates.get('input', 0))),
                'output_cost_per_1m': Decimal(str(rates.get('output', 0))),
                'cost_per_image': None,
                'is_active': True,
                '_from_constants': True,
            }

        # Check image models
        if model_id in IMAGE_MODEL_RATES:
            cost = IMAGE_MODEL_RATES[model_id]
            return {
                'model_name': model_id,
                'display_name': model_id,
                'model_type': 'image',
                'provider': 'openai' if 'dall-e' in model_id else 'runware',
                'input_cost_per_1m': None,
                'output_cost_per_1m': None,
                'cost_per_image': Decimal(str(cost)),
                'is_active': True,
                '_from_constants': True,
            }

        return None

    @classmethod
    def _rate_from_config(cls, model: Any, rate_type: str) -> Decimal:
        """
        Read one rate from an already-resolved model configuration.

        Args:
            model: AIModelConfig instance or fallback dict
            rate_type: 'input', 'output' or 'image'

        Returns:
            The stored rate, or Decimal('0') when the rate type is unknown or
            the stored value is empty (None / 0).
        """
        field_name = cls._RATE_FIELDS.get(rate_type)
        if field_name is None:
            return Decimal('0')

        if isinstance(model, dict):
            value = model.get(field_name)
        else:
            value = getattr(model, field_name)
        return value or Decimal('0')

    @classmethod
    def _cost_from_config(cls, model: Any, input_tokens: int = 0,
                          output_tokens: int = 0, num_images: int = 0) -> Decimal:
        """
        Compute the usage cost from an already-resolved model configuration.

        Working from a single configuration object avoids repeated lookups and
        guarantees the model type and the rates come from the same snapshot.
        Counts passed as None are treated as 0.
        """
        model_type = model.get('model_type') if isinstance(model, dict) else model.model_type

        if model_type == 'text':
            input_rate = cls._rate_from_config(model, 'input')
            output_rate = cls._rate_from_config(model, 'output')
            # Rates are expressed per one million tokens.
            total = (
                Decimal(input_tokens or 0) * input_rate
                + Decimal(output_tokens or 0) * output_rate
            )
            return total / Decimal('1000000')

        if model_type == 'image':
            return cls._rate_from_config(model, 'image') * Decimal(num_images or 0)

        # Any other model type has no pricing rule here.
        return Decimal('0')

    @classmethod
    def get_model(cls, model_id: str) -> Optional[Any]:
        """
        Get model configuration by model_id.

        Order of lookup:
        1. Cache
        2. Database (AIModelConfig)
        3. constants.py fallback

        Database and fallback hits are cached for MODEL_CACHE_TTL seconds;
        misses are not cached.

        Args:
            model_id: The model identifier (e.g., 'gpt-4o-mini', 'dall-e-3')

        Returns:
            AIModelConfig instance or dict with model config, None if not found
        """
        cache_key = cls._get_cache_key(model_id)

        # Try cache first
        cached = cache.get(cache_key)
        if cached is not None:
            return cached

        # Try database
        model_config = cls._get_from_db(model_id)
        if model_config:
            cache.set(cache_key, model_config, MODEL_CACHE_TTL)
            return model_config

        # Fallback to constants
        fallback = cls._get_from_constants(model_id)
        if fallback:
            cache.set(cache_key, fallback, MODEL_CACHE_TTL)
            return fallback

        logger.warning("Model %s not found in DB or constants", model_id)
        return None

    @classmethod
    def get_rate(cls, model_id: str, rate_type: str) -> Decimal:
        """
        Get specific rate for a model.

        Args:
            model_id: The model identifier
            rate_type: 'input', 'output' (for text models) or 'image' (for image models)

        Returns:
            Decimal rate value, 0 if the model, the rate type or the stored
            value is not found
        """
        model = cls.get_model(model_id)
        if not model:
            return Decimal('0')
        return cls._rate_from_config(model, rate_type)

    @classmethod
    def calculate_cost(cls, model_id: str, input_tokens: int = 0, output_tokens: int = 0, num_images: int = 0) -> Decimal:
        """
        Calculate cost for model usage.

        For text models: Uses input/output token counts
        For image models: Uses num_images

        Args:
            model_id: The model identifier
            input_tokens: Number of input tokens (for text models)
            output_tokens: Number of output tokens (for text models)
            num_images: Number of images (for image models)

        Returns:
            Decimal cost in USD; 0 when the model is unknown or its type is
            neither 'text' nor 'image'
        """
        # Resolve the model once and price from that single snapshot.
        model = cls.get_model(model_id)
        if not model:
            return Decimal('0')
        return cls._cost_from_config(model, input_tokens, output_tokens, num_images)

    @classmethod
    def get_default_model(cls, model_type: str = 'text') -> Optional[str]:
        """
        Get the default model for a given type.

        The database is consulted first (active row flagged ``is_default``);
        if that yields nothing or fails, the constants fallback is used.

        Args:
            model_type: 'text' or 'image'

        Returns:
            model_id string or None
        """
        try:
            from igny8_core.business.billing.models import AIModelConfig
            default = AIModelConfig.objects.filter(
                model_type=model_type,
                is_active=True,
                is_default=True,
            ).first()

            if default:
                return default.model_name
        except Exception as e:
            logger.debug("Could not get default %s model from DB: %s", model_type, e)

        # Fallback to constants
        from igny8_core.ai.constants import DEFAULT_AI_MODEL
        if model_type == 'text':
            return DEFAULT_AI_MODEL
        if model_type == 'image':
            return 'dall-e-3'

        return None

    @classmethod
    def list_models(cls, model_type: Optional[str] = None, provider: Optional[str] = None) -> list:
        """
        List all available models, optionally filtered by type or provider.

        Active database rows are returned when the (filtered) query yields
        any; otherwise the constants tables are used, with the same type and
        provider filters applied.

        Args:
            model_type: Filter by 'text', 'image', or 'embedding'
            provider: Filter by 'openai', 'anthropic', 'runware', etc.

        Returns:
            List of model configs (AIModelConfig instances or fallback dicts)
        """
        models = []

        try:
            from igny8_core.business.billing.models import AIModelConfig
            queryset = AIModelConfig.objects.filter(is_active=True)

            if model_type:
                queryset = queryset.filter(model_type=model_type)
            if provider:
                queryset = queryset.filter(provider=provider)

            models = list(queryset.order_by('sort_order', 'model_name'))
        except Exception as e:
            logger.debug("Could not list models from DB: %s", e)

        # Add models from constants if not in DB
        if not models:
            from igny8_core.ai.constants import MODEL_RATES, IMAGE_MODEL_RATES

            candidate_ids = []
            if model_type in (None, 'text'):
                candidate_ids.extend(MODEL_RATES)
            if model_type in (None, 'image'):
                candidate_ids.extend(IMAGE_MODEL_RATES)

            for candidate_id in candidate_ids:
                fallback = cls._get_from_constants(candidate_id)
                if not fallback:
                    continue
                # Honour the provider filter for fallback entries as well, so
                # a provider-specific query never returns other providers.
                if provider and fallback['provider'] != provider:
                    continue
                models.append(fallback)

        return models

    @classmethod
    def clear_cache(cls, model_id: Optional[str] = None):
        """
        Clear model cache.

        When no model_id is given and the cache backend offers
        ``delete_pattern``, every key with the module prefix is removed.
        Otherwise the keys of all models known from constants.py and from the
        database are deleted one by one. Failures are logged as warnings.

        Args:
            model_id: Clear specific model cache, or all if None
        """
        if model_id:
            cache.delete(cls._get_cache_key(model_id))
            return

        # Clear all model caches - use pattern if available
        try:
            from django.core.cache import caches
            default_cache = caches['default']
            if hasattr(default_cache, 'delete_pattern'):
                default_cache.delete_pattern(f"{CACHE_KEY_PREFIX}*")
            else:
                # Fallback: clear every model we can enumerate. Database names
                # are included so DB-only models do not stay cached.
                from igny8_core.ai.constants import MODEL_RATES, IMAGE_MODEL_RATES
                known_ids = set(MODEL_RATES) | set(IMAGE_MODEL_RATES)
                known_ids.update(cls._get_db_model_names())
                for known_id in known_ids:
                    cache.delete(cls._get_cache_key(known_id))
        except Exception as e:
            logger.warning("Could not clear all model caches: %s", e)

    @classmethod
    def validate_model(cls, model_id: str) -> bool:
        """
        Check if a model ID is valid and active.

        Args:
            model_id: The model identifier to validate

        Returns:
            True if model exists and is active, False otherwise
        """
        model = cls.get_model(model_id)
        if not model:
            return False

        # Check if active
        if isinstance(model, dict):
            return bool(model.get('is_active', True))
        return bool(model.is_active)
|