django admin Groups reorg, Frontend udpates for site settings, #Migration runs
This commit is contained in:
@@ -1,11 +1,10 @@
|
||||
"""
|
||||
Model Registry Service
|
||||
Central registry for AI model configurations with caching.
|
||||
Replaces hardcoded MODEL_RATES and IMAGE_MODEL_RATES from constants.py
|
||||
|
||||
This service provides:
|
||||
- Database-driven model configuration (from AIModelConfig)
|
||||
- Fallback to constants.py for backward compatibility
|
||||
- Integration provider API key retrieval (from IntegrationProvider)
|
||||
- Caching for performance
|
||||
- Cost calculation methods
|
||||
|
||||
@@ -20,6 +19,9 @@ Usage:
|
||||
|
||||
# Calculate cost
|
||||
cost = ModelRegistry.calculate_cost('gpt-4o-mini', input_tokens=1000, output_tokens=500)
|
||||
|
||||
# Get API key for a provider
|
||||
api_key = ModelRegistry.get_api_key('openai')
|
||||
"""
|
||||
import logging
|
||||
from decimal import Decimal
|
||||
@@ -33,12 +35,14 @@ MODEL_CACHE_TTL = 300
|
||||
|
||||
# Cache key prefix
|
||||
CACHE_KEY_PREFIX = 'ai_model_'
|
||||
PROVIDER_CACHE_PREFIX = 'provider_'
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
"""
|
||||
Central registry for AI model configurations with caching.
|
||||
Uses AIModelConfig from database with fallback to constants.py
|
||||
Uses AIModelConfig from database for model configs.
|
||||
Uses IntegrationProvider for API keys.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@@ -46,6 +50,11 @@ class ModelRegistry:
|
||||
"""Generate cache key for model"""
|
||||
return f"{CACHE_KEY_PREFIX}{model_id}"
|
||||
|
||||
@classmethod
|
||||
def _get_provider_cache_key(cls, provider_id: str) -> str:
|
||||
"""Generate cache key for provider"""
|
||||
return f"{PROVIDER_CACHE_PREFIX}{provider_id}"
|
||||
|
||||
@classmethod
|
||||
def _get_from_db(cls, model_id: str) -> Optional[Any]:
|
||||
"""Get model config from database"""
|
||||
@@ -59,46 +68,6 @@ class ModelRegistry:
|
||||
logger.debug(f"Could not fetch model {model_id} from DB: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_from_constants(cls, model_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get model config from constants.py as fallback.
|
||||
Returns a dict mimicking AIModelConfig attributes.
|
||||
"""
|
||||
from igny8_core.ai.constants import MODEL_RATES, IMAGE_MODEL_RATES
|
||||
|
||||
# Check text models first
|
||||
if model_id in MODEL_RATES:
|
||||
rates = MODEL_RATES[model_id]
|
||||
return {
|
||||
'model_name': model_id,
|
||||
'display_name': model_id,
|
||||
'model_type': 'text',
|
||||
'provider': 'openai',
|
||||
'input_cost_per_1m': Decimal(str(rates.get('input', 0))),
|
||||
'output_cost_per_1m': Decimal(str(rates.get('output', 0))),
|
||||
'cost_per_image': None,
|
||||
'is_active': True,
|
||||
'_from_constants': True
|
||||
}
|
||||
|
||||
# Check image models
|
||||
if model_id in IMAGE_MODEL_RATES:
|
||||
cost = IMAGE_MODEL_RATES[model_id]
|
||||
return {
|
||||
'model_name': model_id,
|
||||
'display_name': model_id,
|
||||
'model_type': 'image',
|
||||
'provider': 'openai' if 'dall-e' in model_id else 'runware',
|
||||
'input_cost_per_1m': None,
|
||||
'output_cost_per_1m': None,
|
||||
'cost_per_image': Decimal(str(cost)),
|
||||
'is_active': True,
|
||||
'_from_constants': True
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_model(cls, model_id: str) -> Optional[Any]:
|
||||
"""
|
||||
@@ -107,13 +76,12 @@ class ModelRegistry:
|
||||
Order of lookup:
|
||||
1. Cache
|
||||
2. Database (AIModelConfig)
|
||||
3. constants.py fallback
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (e.g., 'gpt-4o-mini', 'dall-e-3')
|
||||
|
||||
Returns:
|
||||
AIModelConfig instance or dict with model config, None if not found
|
||||
AIModelConfig instance, None if not found
|
||||
"""
|
||||
cache_key = cls._get_cache_key(model_id)
|
||||
|
||||
@@ -129,13 +97,7 @@ class ModelRegistry:
|
||||
cache.set(cache_key, model_config, MODEL_CACHE_TTL)
|
||||
return model_config
|
||||
|
||||
# Fallback to constants
|
||||
fallback = cls._get_from_constants(model_id)
|
||||
if fallback:
|
||||
cache.set(cache_key, fallback, MODEL_CACHE_TTL)
|
||||
return fallback
|
||||
|
||||
logger.warning(f"Model {model_id} not found in DB or constants")
|
||||
logger.warning(f"Model {model_id} not found in database")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -154,16 +116,6 @@ class ModelRegistry:
|
||||
if not model:
|
||||
return Decimal('0')
|
||||
|
||||
# Handle dict (from constants fallback)
|
||||
if isinstance(model, dict):
|
||||
if rate_type == 'input':
|
||||
return model.get('input_cost_per_1m') or Decimal('0')
|
||||
elif rate_type == 'output':
|
||||
return model.get('output_cost_per_1m') or Decimal('0')
|
||||
elif rate_type == 'image':
|
||||
return model.get('cost_per_image') or Decimal('0')
|
||||
return Decimal('0')
|
||||
|
||||
# Handle AIModelConfig instance
|
||||
if rate_type == 'input':
|
||||
return model.input_cost_per_1m or Decimal('0')
|
||||
@@ -195,8 +147,8 @@ class ModelRegistry:
|
||||
if not model:
|
||||
return Decimal('0')
|
||||
|
||||
# Determine model type
|
||||
model_type = model.get('model_type') if isinstance(model, dict) else model.model_type
|
||||
# Get model type from AIModelConfig
|
||||
model_type = model.model_type
|
||||
|
||||
if model_type == 'text':
|
||||
input_rate = cls.get_rate(model_id, 'input')
|
||||
@@ -218,7 +170,7 @@ class ModelRegistry:
|
||||
@classmethod
|
||||
def get_default_model(cls, model_type: str = 'text') -> Optional[str]:
|
||||
"""
|
||||
Get the default model for a given type.
|
||||
Get the default model for a given type from database.
|
||||
|
||||
Args:
|
||||
model_type: 'text' or 'image'
|
||||
@@ -236,32 +188,33 @@ class ModelRegistry:
|
||||
|
||||
if default:
|
||||
return default.model_name
|
||||
|
||||
# If no default is set, return first active model of this type
|
||||
first_active = AIModelConfig.objects.filter(
|
||||
model_type=model_type,
|
||||
is_active=True
|
||||
).order_by('model_name').first()
|
||||
|
||||
if first_active:
|
||||
return first_active.model_name
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get default {model_type} model from DB: {e}")
|
||||
|
||||
# Fallback to constants
|
||||
from igny8_core.ai.constants import DEFAULT_AI_MODEL
|
||||
if model_type == 'text':
|
||||
return DEFAULT_AI_MODEL
|
||||
elif model_type == 'image':
|
||||
return 'dall-e-3'
|
||||
logger.error(f"Could not get default {model_type} model from DB: {e}")
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def list_models(cls, model_type: Optional[str] = None, provider: Optional[str] = None) -> list:
|
||||
"""
|
||||
List all available models, optionally filtered by type or provider.
|
||||
List all available models from database, optionally filtered by type or provider.
|
||||
|
||||
Args:
|
||||
model_type: Filter by 'text', 'image', or 'embedding'
|
||||
provider: Filter by 'openai', 'anthropic', 'runware', etc.
|
||||
|
||||
Returns:
|
||||
List of model configs
|
||||
List of AIModelConfig instances
|
||||
"""
|
||||
models = []
|
||||
|
||||
try:
|
||||
from igny8_core.business.billing.models import AIModelConfig
|
||||
queryset = AIModelConfig.objects.filter(is_active=True)
|
||||
@@ -271,27 +224,10 @@ class ModelRegistry:
|
||||
if provider:
|
||||
queryset = queryset.filter(provider=provider)
|
||||
|
||||
models = list(queryset.order_by('sort_order', 'model_name'))
|
||||
return list(queryset.order_by('model_name'))
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not list models from DB: {e}")
|
||||
|
||||
# Add models from constants if not in DB
|
||||
if not models:
|
||||
from igny8_core.ai.constants import MODEL_RATES, IMAGE_MODEL_RATES
|
||||
|
||||
if model_type in (None, 'text'):
|
||||
for model_id in MODEL_RATES:
|
||||
fallback = cls._get_from_constants(model_id)
|
||||
if fallback:
|
||||
models.append(fallback)
|
||||
|
||||
if model_type in (None, 'image'):
|
||||
for model_id in IMAGE_MODEL_RATES:
|
||||
fallback = cls._get_from_constants(model_id)
|
||||
if fallback:
|
||||
models.append(fallback)
|
||||
|
||||
return models
|
||||
logger.error(f"Could not list models from DB: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls, model_id: Optional[str] = None):
|
||||
@@ -311,10 +247,10 @@ class ModelRegistry:
|
||||
if hasattr(default_cache, 'delete_pattern'):
|
||||
default_cache.delete_pattern(f"{CACHE_KEY_PREFIX}*")
|
||||
else:
|
||||
# Fallback: clear known models
|
||||
from igny8_core.ai.constants import MODEL_RATES, IMAGE_MODEL_RATES
|
||||
for model_id in list(MODEL_RATES.keys()) + list(IMAGE_MODEL_RATES.keys()):
|
||||
cache.delete(cls._get_cache_key(model_id))
|
||||
# Fallback: clear all known models from DB
|
||||
from igny8_core.business.billing.models import AIModelConfig
|
||||
for model in AIModelConfig.objects.values_list('model_name', flat=True):
|
||||
cache.delete(cls._get_cache_key(model))
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not clear all model caches: {e}")
|
||||
|
||||
@@ -332,8 +268,110 @@ class ModelRegistry:
|
||||
model = cls.get_model(model_id)
|
||||
if not model:
|
||||
return False
|
||||
|
||||
# Check if active
|
||||
if isinstance(model, dict):
|
||||
return model.get('is_active', True)
|
||||
return model.is_active
|
||||
|
||||
# ========== IntegrationProvider methods ==========
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, provider_id: str) -> Optional[Any]:
|
||||
"""
|
||||
Get IntegrationProvider by provider_id.
|
||||
|
||||
Args:
|
||||
provider_id: The provider identifier (e.g., 'openai', 'stripe', 'resend')
|
||||
|
||||
Returns:
|
||||
IntegrationProvider instance, None if not found
|
||||
"""
|
||||
cache_key = cls._get_provider_cache_key(provider_id)
|
||||
|
||||
# Try cache first
|
||||
cached = cache.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
from igny8_core.modules.system.models import IntegrationProvider
|
||||
provider = IntegrationProvider.objects.filter(
|
||||
provider_id=provider_id,
|
||||
is_active=True
|
||||
).first()
|
||||
|
||||
if provider:
|
||||
cache.set(cache_key, provider, MODEL_CACHE_TTL)
|
||||
return provider
|
||||
except Exception as e:
|
||||
logger.error(f"Could not fetch provider {provider_id} from DB: {e}")
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_api_key(cls, provider_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get API key for a provider.
|
||||
|
||||
Args:
|
||||
provider_id: The provider identifier (e.g., 'openai', 'anthropic', 'runware')
|
||||
|
||||
Returns:
|
||||
API key string, None if not found or provider is inactive
|
||||
"""
|
||||
provider = cls.get_provider(provider_id)
|
||||
if provider and provider.api_key:
|
||||
return provider.api_key
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_api_secret(cls, provider_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get API secret for a provider (for OAuth, Stripe secret key, etc.).
|
||||
|
||||
Args:
|
||||
provider_id: The provider identifier
|
||||
|
||||
Returns:
|
||||
API secret string, None if not found
|
||||
"""
|
||||
provider = cls.get_provider(provider_id)
|
||||
if provider and provider.api_secret:
|
||||
return provider.api_secret
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_webhook_secret(cls, provider_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get webhook secret for a provider (for Stripe, PayPal webhooks).
|
||||
|
||||
Args:
|
||||
provider_id: The provider identifier
|
||||
|
||||
Returns:
|
||||
Webhook secret string, None if not found
|
||||
"""
|
||||
provider = cls.get_provider(provider_id)
|
||||
if provider and provider.webhook_secret:
|
||||
return provider.webhook_secret
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def clear_provider_cache(cls, provider_id: Optional[str] = None):
|
||||
"""
|
||||
Clear provider cache.
|
||||
|
||||
Args:
|
||||
provider_id: Clear specific provider cache, or all if None
|
||||
"""
|
||||
if provider_id:
|
||||
cache.delete(cls._get_provider_cache_key(provider_id))
|
||||
else:
|
||||
try:
|
||||
from django.core.cache import caches
|
||||
default_cache = caches['default']
|
||||
if hasattr(default_cache, 'delete_pattern'):
|
||||
default_cache.delete_pattern(f"{PROVIDER_CACHE_PREFIX}*")
|
||||
else:
|
||||
from igny8_core.modules.system.models import IntegrationProvider
|
||||
for pid in IntegrationProvider.objects.values_list('provider_id', flat=True):
|
||||
cache.delete(cls._get_provider_cache_key(pid))
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not clear provider caches: {e}")
|
||||
|
||||
Reference in New Issue
Block a user