378 lines
12 KiB
Python
378 lines
12 KiB
Python
"""
|
|
Model Registry Service
|
|
Central registry for AI model configurations with caching.
|
|
|
|
This service provides:
|
|
- Database-driven model configuration (from AIModelConfig)
|
|
- Integration provider API key retrieval (from IntegrationProvider)
|
|
- Caching for performance
|
|
- Cost calculation methods
|
|
|
|
Usage:
|
|
from igny8_core.ai.model_registry import ModelRegistry
|
|
|
|
# Get model config
|
|
model = ModelRegistry.get_model('gpt-4o-mini')
|
|
|
|
# Get rate
|
|
input_rate = ModelRegistry.get_rate('gpt-4o-mini', 'input')
|
|
|
|
# Calculate cost
|
|
cost = ModelRegistry.calculate_cost('gpt-4o-mini', input_tokens=1000, output_tokens=500)
|
|
|
|
# Get API key for a provider
|
|
api_key = ModelRegistry.get_api_key('openai')
|
|
"""
|
|
import logging
|
|
from decimal import Decimal
|
|
from typing import Optional, Dict, Any
|
|
from django.core.cache import cache
|
|
|
|
logger = logging.getLogger(name=__name__)

# Lifetime, in seconds, of every cached model-config and provider entry (5 minutes).
MODEL_CACHE_TTL = 5 * 60

# Key prefixes that namespace this registry's entries inside the shared Django cache.
CACHE_KEY_PREFIX = 'ai_model_'        # AIModelConfig entries
PROVIDER_CACHE_PREFIX = 'provider_'   # IntegrationProvider entries
|
|
|
|
|
|
class ModelRegistry:
    """
    Central registry for AI model configurations with caching.

    Model configs are read from the ``AIModelConfig`` table and provider
    credentials from the ``IntegrationProvider`` table. Single-model and
    single-provider lookups are cached in the Django cache for
    ``MODEL_CACHE_TTL`` seconds. Every method is a classmethod/staticmethod;
    the class is used as a namespace and never instantiated.
    """

    # ========== Internal helpers ==========

    @classmethod
    def _get_cache_key(cls, model_id: str) -> str:
        """Return the cache key under which a model config is stored."""
        return f"{CACHE_KEY_PREFIX}{model_id}"

    @classmethod
    def _get_provider_cache_key(cls, provider_id: str) -> str:
        """Return the cache key under which a provider is stored."""
        return f"{PROVIDER_CACHE_PREFIX}{provider_id}"

    @staticmethod
    def _rate_from_model(model: Any, rate_type: str) -> Decimal:
        """
        Read one rate from an already-fetched AIModelConfig instance.

        Shared by ``get_rate`` and ``calculate_cost`` so that a cost
        calculation reads all of its rates from the same instance instead of
        re-fetching the model for every rate.

        Args:
            model: AIModelConfig instance (must not be None)
            rate_type: 'input', 'output' (text models) or 'image' (image models)

        Returns:
            The stored rate, or Decimal('0') if the field is empty/None or
            ``rate_type`` is not recognised.
        """
        if rate_type == 'input':
            value = model.cost_per_1k_input
        elif rate_type == 'output':
            value = model.cost_per_1k_output
        elif rate_type == 'image':
            value = model.cost_per_image
        else:
            return Decimal('0')
        return value or Decimal('0')

    @classmethod
    def _get_from_db(cls, model_id: str) -> Optional[Any]:
        """
        Fetch the active AIModelConfig whose ``model_name`` equals ``model_id``.

        Returns None when no active row matches or when the query raises
        (the error is logged at DEBUG level and swallowed).
        """
        try:
            from igny8_core.business.billing.models import AIModelConfig
            return AIModelConfig.objects.filter(
                model_name=model_id,
                is_active=True,
            ).first()
        except Exception as e:
            logger.debug("Could not fetch model %s from DB: %s", model_id, e)
            return None

    # ========== AIModelConfig methods ==========

    @classmethod
    def get_model(cls, model_id: str) -> Optional[Any]:
        """
        Get model configuration by model_id.

        Order of lookup:
            1. Cache
            2. Database (AIModelConfig, active rows only)

        Only hits are cached; a miss queries the database again on every call
        and logs a warning.

        Args:
            model_id: The model identifier (e.g., 'gpt-4o-mini', 'dall-e-3')

        Returns:
            AIModelConfig instance, None if not found
        """
        cache_key = cls._get_cache_key(model_id)

        cached = cache.get(cache_key)
        if cached is not None:
            return cached

        model_config = cls._get_from_db(model_id)
        if model_config:
            cache.set(cache_key, model_config, MODEL_CACHE_TTL)
            return model_config

        logger.warning("Model %s not found in database", model_id)
        return None

    @classmethod
    def get_rate(cls, model_id: str, rate_type: str) -> Decimal:
        """
        Get a specific rate for a model.

        Args:
            model_id: The model identifier
            rate_type: 'input', 'output' (for text models) or 'image' (for image models)

        Returns:
            Decimal rate value; Decimal('0') if the model is unknown, the field
            is empty, or ``rate_type`` is not recognised.
        """
        model = cls.get_model(model_id)
        if not model:
            return Decimal('0')
        return cls._rate_from_model(model, rate_type)

    @classmethod
    def calculate_cost(cls, model_id: str, input_tokens: int = 0, output_tokens: int = 0, num_images: int = 0) -> Decimal:
        """
        Calculate cost for model usage.

        For text models: uses input/output token counts.
        For image models: uses num_images.
        Any other ``model_type`` (e.g. 'embedding', which ``list_models``
        accepts as a filter) is not priced here and returns Decimal('0').

        Args:
            model_id: The model identifier
            input_tokens: Number of input tokens (for text models); None is treated as 0
            output_tokens: Number of output tokens (for text models); None is treated as 0
            num_images: Number of images (for image models); None is treated as 0

        Returns:
            Decimal cost in USD (Decimal('0') if the model is unknown)
        """
        # Fetch once and read every rate from this same instance, so the
        # type check and the rates cannot come from different lookups.
        model = cls.get_model(model_id)
        if not model:
            return Decimal('0')

        model_type = model.model_type

        if model_type == 'text':
            input_rate = cls._rate_from_model(model, 'input')
            output_rate = cls._rate_from_model(model, 'output')
            # `or 0` guards against callers passing None when usage data is missing.
            token_cost = (
                Decimal(input_tokens or 0) * input_rate
                + Decimal(output_tokens or 0) * output_rate
            )
            # NOTE(review): the rate fields are named cost_per_1k_*, but the sum
            # is divided by 1,000,000 (a per-1M-token convention). One of the two
            # is off by a factor of 1000; confirm the unit AIModelConfig stores.
            return token_cost / Decimal('1000000')

        if model_type == 'image':
            image_rate = cls._rate_from_model(model, 'image')
            return image_rate * Decimal(num_images or 0)

        return Decimal('0')

    @classmethod
    def get_default_model(cls, model_type: str = 'text') -> Optional[str]:
        """
        Get the default model for a given type from the database (uncached).

        Prefers the active row flagged ``is_default``; otherwise falls back to
        the first active model of that type ordered by ``model_name``.

        Args:
            model_type: 'text' or 'image'

        Returns:
            model_id string, or None if none exists or the query fails
            (the failure is logged at ERROR level).
        """
        try:
            from igny8_core.business.billing.models import AIModelConfig
            default = AIModelConfig.objects.filter(
                model_type=model_type,
                is_active=True,
                is_default=True,
            ).first()

            if default:
                return default.model_name

            # No explicit default: pick a deterministic first active model.
            first_active = AIModelConfig.objects.filter(
                model_type=model_type,
                is_active=True,
            ).order_by('model_name').first()

            if first_active:
                return first_active.model_name

        except Exception as e:
            logger.error("Could not get default %s model from DB: %s", model_type, e)

        return None

    @classmethod
    def list_models(cls, model_type: Optional[str] = None, provider: Optional[str] = None) -> list:
        """
        List active models from the database (uncached), optionally filtered.

        Args:
            model_type: Filter by 'text', 'image', or 'embedding'
            provider: Filter by 'openai', 'anthropic', 'runware', etc.

        Returns:
            List of AIModelConfig instances ordered by ``model_name``; an empty
            list if the query fails (the failure is logged at ERROR level).
        """
        try:
            from igny8_core.business.billing.models import AIModelConfig
            queryset = AIModelConfig.objects.filter(is_active=True)

            if model_type:
                queryset = queryset.filter(model_type=model_type)
            if provider:
                queryset = queryset.filter(provider=provider)

            return list(queryset.order_by('model_name'))
        except Exception as e:
            logger.error("Could not list models from DB: %s", e)
            return []

    @classmethod
    def clear_cache(cls, model_id: Optional[str] = None):
        """
        Clear model cache.

        Args:
            model_id: Clear one model's cache entry; if falsy (None or ''),
                clear the entries of all models. Failures while clearing all
                entries are logged at WARNING level and swallowed.
        """
        if model_id:
            cache.delete(cls._get_cache_key(model_id))
            return

        try:
            from django.core.cache import caches
            default_cache = caches['default']
            if hasattr(default_cache, 'delete_pattern'):
                # Backends exposing delete_pattern (such as django-redis) can
                # drop every prefixed key in one call.
                default_cache.delete_pattern(f"{CACHE_KEY_PREFIX}*")
            else:
                # Fallback: delete the key of every model known to the DB
                # (active or not) in a single batched call.
                from igny8_core.business.billing.models import AIModelConfig
                cache.delete_many([
                    cls._get_cache_key(name)
                    for name in AIModelConfig.objects.values_list('model_name', flat=True)
                ])
        except Exception as e:
            logger.warning("Could not clear all model caches: %s", e)

    @classmethod
    def validate_model(cls, model_id: str) -> bool:
        """
        Check if a model ID is valid and active.

        Args:
            model_id: The model identifier to validate

        Returns:
            True if the model exists and is active, False otherwise. The result
            reflects the cached instance, so it can lag a DB change by up to
            ``MODEL_CACHE_TTL`` seconds.
        """
        model = cls.get_model(model_id)
        if not model:
            return False
        return model.is_active

    # ========== IntegrationProvider methods ==========

    @classmethod
    def get_provider(cls, provider_id: str) -> Optional[Any]:
        """
        Get an active IntegrationProvider by provider_id (cache first, then DB).

        NOTE(review): the cached instance carries the provider's credentials
        (read by get_api_key / get_api_secret / get_webhook_secret), so they
        are stored in the cache backend. Confirm that backend is trusted.

        Args:
            provider_id: The provider identifier (e.g., 'openai', 'stripe', 'resend')

        Returns:
            IntegrationProvider instance, None if not found or on error
            (errors are logged at ERROR level).
        """
        cache_key = cls._get_provider_cache_key(provider_id)

        cached = cache.get(cache_key)
        if cached is not None:
            return cached

        try:
            from igny8_core.modules.system.models import IntegrationProvider
            provider = IntegrationProvider.objects.filter(
                provider_id=provider_id,
                is_active=True,
            ).first()

            if provider:
                cache.set(cache_key, provider, MODEL_CACHE_TTL)
                return provider
        except Exception as e:
            logger.error("Could not fetch provider %s from DB: %s", provider_id, e)

        return None

    @classmethod
    def _get_provider_credential(cls, provider_id: str, field_name: str) -> Optional[str]:
        """Return the named credential field of a provider, or None if the provider is missing or the field is empty."""
        provider = cls.get_provider(provider_id)
        if not provider:
            return None
        return getattr(provider, field_name) or None

    @classmethod
    def get_api_key(cls, provider_id: str) -> Optional[str]:
        """
        Get API key for a provider.

        Args:
            provider_id: The provider identifier (e.g., 'openai', 'anthropic', 'runware')

        Returns:
            API key string, None if not found, empty, or provider is inactive
        """
        return cls._get_provider_credential(provider_id, 'api_key')

    @classmethod
    def get_api_secret(cls, provider_id: str) -> Optional[str]:
        """
        Get API secret for a provider (for OAuth, Stripe secret key, etc.).

        Args:
            provider_id: The provider identifier

        Returns:
            API secret string, None if not found or empty
        """
        return cls._get_provider_credential(provider_id, 'api_secret')

    @classmethod
    def get_webhook_secret(cls, provider_id: str) -> Optional[str]:
        """
        Get webhook secret for a provider (for Stripe, PayPal webhooks).

        Args:
            provider_id: The provider identifier

        Returns:
            Webhook secret string, None if not found or empty
        """
        return cls._get_provider_credential(provider_id, 'webhook_secret')

    @classmethod
    def clear_provider_cache(cls, provider_id: Optional[str] = None):
        """
        Clear provider cache.

        Args:
            provider_id: Clear one provider's cache entry; if falsy (None or
                ''), clear the entries of all providers. Failures while clearing
                all entries are logged at WARNING level and swallowed.
        """
        if provider_id:
            cache.delete(cls._get_provider_cache_key(provider_id))
            return

        try:
            from django.core.cache import caches
            default_cache = caches['default']
            if hasattr(default_cache, 'delete_pattern'):
                default_cache.delete_pattern(f"{PROVIDER_CACHE_PREFIX}*")
            else:
                # Fallback: delete the key of every provider known to the DB
                # (active or not) in a single batched call.
                from igny8_core.modules.system.models import IntegrationProvider
                cache.delete_many([
                    cls._get_provider_cache_key(pid)
                    for pid in IntegrationProvider.objects.values_list('provider_id', flat=True)
                ])
        except Exception as e:
            logger.warning("Could not clear provider caches: %s", e)
|