Stage 1 & 2 refactor of AI engine
This commit is contained in:
@@ -1,4 +1,17 @@
|
||||
"""
|
||||
AI Function implementations
|
||||
"""
|
||||
from igny8_core.ai.functions.auto_cluster import AutoClusterFunction
|
||||
from igny8_core.ai.functions.generate_ideas import GenerateIdeasFunction, generate_ideas_core
|
||||
from igny8_core.ai.functions.generate_content import GenerateContentFunction, generate_content_core
|
||||
from igny8_core.ai.functions.generate_images import GenerateImagesFunction, generate_images_core
|
||||
|
||||
__all__ = [
|
||||
'AutoClusterFunction',
|
||||
'GenerateIdeasFunction',
|
||||
'generate_ideas_core',
|
||||
'GenerateContentFunction',
|
||||
'generate_content_core',
|
||||
'GenerateImagesFunction',
|
||||
'generate_images_core',
|
||||
]
|
||||
|
||||
@@ -7,6 +7,7 @@ from django.db import transaction
|
||||
from igny8_core.ai.base import BaseAIFunction
|
||||
from igny8_core.modules.planner.models import Keywords, Clusters
|
||||
from igny8_core.modules.system.utils import get_prompt_value
|
||||
from igny8_core.ai.ai_core import AICore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,49 +37,23 @@ class AutoClusterFunction(BaseAIFunction):
|
||||
|
||||
def validate(self, payload: dict, account=None) -> Dict:
|
||||
"""Custom validation for clustering with plan limit checks"""
|
||||
result = super().validate(payload, account)
|
||||
from igny8_core.ai.validators import validate_ids, validate_keywords_exist, validate_cluster_limits
|
||||
|
||||
# Base validation
|
||||
result = validate_ids(payload, max_items=self.get_max_items())
|
||||
if not result['valid']:
|
||||
return result
|
||||
|
||||
# Additional validation: check keywords exist
|
||||
# Check keywords exist
|
||||
ids = payload.get('ids', [])
|
||||
queryset = Keywords.objects.filter(id__in=ids)
|
||||
if account:
|
||||
queryset = queryset.filter(account=account)
|
||||
keywords_result = validate_keywords_exist(ids, account)
|
||||
if not keywords_result['valid']:
|
||||
return keywords_result
|
||||
|
||||
if queryset.count() == 0:
|
||||
return {'valid': False, 'error': 'No keywords found'}
|
||||
|
||||
# Plan limit validation
|
||||
if account:
|
||||
plan = getattr(account, 'plan', None)
|
||||
if plan:
|
||||
from django.utils import timezone
|
||||
from igny8_core.modules.planner.models import Clusters
|
||||
|
||||
# Check daily cluster limit
|
||||
now = timezone.now()
|
||||
start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
clusters_today = Clusters.objects.filter(
|
||||
account=account,
|
||||
created_at__gte=start_of_day
|
||||
).count()
|
||||
|
||||
if plan.daily_cluster_limit and clusters_today >= plan.daily_cluster_limit:
|
||||
return {
|
||||
'valid': False,
|
||||
'error': f'Daily cluster limit reached ({plan.daily_cluster_limit} clusters per day). Please try again tomorrow.'
|
||||
}
|
||||
|
||||
# Check max clusters limit
|
||||
total_clusters = Clusters.objects.filter(account=account).count()
|
||||
if plan.max_clusters and total_clusters >= plan.max_clusters:
|
||||
return {
|
||||
'valid': False,
|
||||
'error': f'Maximum cluster limit reached ({plan.max_clusters} clusters). Please upgrade your plan or delete existing clusters.'
|
||||
}
|
||||
else:
|
||||
return {'valid': False, 'error': 'Account does not have an active plan'}
|
||||
# Check plan limits
|
||||
limit_result = validate_cluster_limits(account, operation_type='cluster')
|
||||
if not limit_result['valid']:
|
||||
return limit_result
|
||||
|
||||
return {'valid': True}
|
||||
|
||||
@@ -158,7 +133,7 @@ class AutoClusterFunction(BaseAIFunction):
|
||||
def parse_response(self, response: str, step_tracker=None) -> List[Dict]:
|
||||
"""Parse AI response into cluster data"""
|
||||
import json
|
||||
from igny8_core.ai.processor import AIProcessor
|
||||
from igny8_core.ai.ai_core import AICore
|
||||
|
||||
if not response or not response.strip():
|
||||
error_msg = "Empty response from AI"
|
||||
@@ -172,8 +147,8 @@ class AutoClusterFunction(BaseAIFunction):
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"parse_response: Direct JSON parse failed: {e}, trying extract_json method")
|
||||
# Fall back to extract_json method which handles markdown code blocks
|
||||
processor = AIProcessor()
|
||||
json_data = processor.extract_json(response)
|
||||
ai_core = AICore(account=getattr(self, 'account', None))
|
||||
json_data = ai_core.extract_json(response)
|
||||
|
||||
if not json_data:
|
||||
error_msg = f"Failed to parse clustering response. Response: {response[:200]}..."
|
||||
|
||||
263
backend/igny8_core/ai/functions/generate_content.py
Normal file
263
backend/igny8_core/ai/functions/generate_content.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""
|
||||
Generate Content AI Function
|
||||
Extracted from modules/writer/tasks.py
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Any
|
||||
from django.db import transaction
|
||||
from igny8_core.ai.base import BaseAIFunction
|
||||
from igny8_core.modules.writer.models import Tasks
|
||||
from igny8_core.modules.system.utils import get_prompt_value, get_default_prompt
|
||||
from igny8_core.ai.ai_core import AICore
|
||||
from igny8_core.ai.validators import validate_tasks_exist
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenerateContentFunction(BaseAIFunction):
    """Generate article content for writer ``Tasks`` using AI.

    Pipeline (driven by the base class / ``generate_content_core``):
    validate -> prepare -> build_prompt (per task) -> parse_response -> save_output.
    """

    def get_name(self) -> str:
        """Return the function name used for AI request tracking/registry."""
        return 'generate_content'

    def get_metadata(self) -> Dict:
        """Return display metadata and per-phase progress messages."""
        return {
            'display_name': 'Generate Content',
            'description': 'Generate article content from task ideas',
            'phases': {
                'INIT': 'Initializing content generation...',
                'PREP': 'Loading tasks and building prompts...',
                'AI_CALL': 'Generating content with AI...',
                'PARSE': 'Processing content...',
                'SAVE': 'Saving content...',
                'DONE': 'Content generated!'
            }
        }

    def get_max_items(self) -> int:
        """Upper bound on the number of task IDs accepted per batch."""
        return 50  # Max tasks per batch

    def validate(self, payload: dict, account=None) -> Dict:
        """Validate task IDs.

        Runs base validation first, then confirms the referenced tasks
        exist (scoped to ``account`` when one is given).

        Returns:
            Dict with ``'valid'`` (bool) and, on failure, an ``'error'`` message.
        """
        result = super().validate(payload, account)
        if not result['valid']:
            return result

        # Check tasks exist (account-scoped when account is provided)
        task_ids = payload.get('ids', [])
        if task_ids:
            task_result = validate_tasks_exist(task_ids, account)
            if not task_result['valid']:
                return task_result

        return {'valid': True}

    def prepare(self, payload: dict, account=None) -> List:
        """Load tasks with all relationships.

        Raises:
            ValueError: if no task matches the given IDs (and account filter).
        """
        task_ids = payload.get('ids', [])

        queryset = Tasks.objects.filter(id__in=task_ids)
        if account:
            queryset = queryset.filter(account=account)

        # Preload all relationships to avoid N+1 queries
        tasks = list(queryset.select_related(
            'account', 'site', 'sector', 'cluster', 'idea'
        ))

        if not tasks:
            raise ValueError("No tasks found")

        return tasks

    def build_prompt(self, data: Any, account=None) -> str:
        """Build the content-generation prompt for a single task.

        Args:
            data: a single Task, or a list of Tasks (only the first is used).
            account: optional account used to look up the prompt template;
                falls back to ``task.account``.

        Returns:
            The prompt template with [IGNY8_IDEA] / [IGNY8_CLUSTER] /
            [IGNY8_KEYWORDS] placeholders replaced.
        """
        if isinstance(data, list):
            # For now, handle single task (will be called per task)
            if not data:
                raise ValueError("No tasks provided")
            task = data[0]
        else:
            task = data

        # Get prompt template (account override, then system default)
        prompt_template = get_prompt_value(account or task.account, 'content_generation')
        if not prompt_template:
            prompt_template = get_default_prompt('content_generation')

        # Build idea data string
        idea_data = f"Title: {task.title or 'Untitled'}\n"
        if task.description:
            idea_data += f"Description: {task.description}\n"

        # Handle idea description (might be JSON outline or plain text)
        if task.idea and task.idea.description:
            description = task.idea.description
            try:
                import json
                parsed_desc = json.loads(description)
                if isinstance(parsed_desc, dict):
                    # Render the {"H2": [{"heading", "subsections": [...]}]}
                    # outline shape into markdown-style headings.
                    formatted_desc = "Content Outline:\n\n"
                    if 'H2' in parsed_desc:
                        for h2_section in parsed_desc['H2']:
                            formatted_desc += f"## {h2_section.get('heading', '')}\n"
                            if 'subsections' in h2_section:
                                for h3_section in h2_section['subsections']:
                                    formatted_desc += f"### {h3_section.get('subheading', '')}\n"
                                    formatted_desc += f"Content Type: {h3_section.get('content_type', '')}\n"
                                    formatted_desc += f"Details: {h3_section.get('details', '')}\n\n"
                    description = formatted_desc
            except (json.JSONDecodeError, TypeError):
                pass  # Use as plain text

            idea_data += f"Outline: {description}\n"

        if task.idea:
            # Prefer the idea's settings, then the task's, then a default
            idea_data += f"Structure: {task.idea.content_structure or task.content_structure or 'blog_post'}\n"
            idea_data += f"Type: {task.idea.content_type or task.content_type or 'blog_post'}\n"
            if task.idea.estimated_word_count:
                idea_data += f"Estimated Word Count: {task.idea.estimated_word_count}\n"

        # Build cluster data string
        cluster_data = ''
        if task.cluster:
            cluster_data = f"Cluster Name: {task.cluster.name or ''}\n"
            if task.cluster.description:
                cluster_data += f"Description: {task.cluster.description}\n"
            cluster_data += f"Status: {task.cluster.status or 'active'}\n"

        # Build keywords string (task keywords, falling back to the idea's)
        keywords_data = task.keywords or ''
        if not keywords_data and task.idea:
            keywords_data = task.idea.target_keywords or ''

        # Replace placeholders
        prompt = prompt_template.replace('[IGNY8_IDEA]', idea_data)
        prompt = prompt.replace('[IGNY8_CLUSTER]', cluster_data)
        prompt = prompt.replace('[IGNY8_KEYWORDS]', keywords_data)

        return prompt

    def parse_response(self, response: str, step_tracker=None) -> str:
        """Parse and normalize content response.

        Normalization is best-effort: any failure is logged and the raw
        response is returned unchanged.
        """
        # Content is already text, just normalize it
        try:
            from igny8_core.utils.content_normalizer import normalize_content
            normalized = normalize_content(response)
            return normalized['normalized_content']
        except Exception as e:
            logger.warning(f"Content normalization failed: {e}, using original")
            return response

    def save_output(
        self,
        parsed: str,
        original_data: Any,
        account=None,
        progress_tracker=None,
        step_tracker=None
    ) -> Dict:
        """Save generated content onto the task and mark it 'draft'.

        Args:
            parsed: the normalized content string.
            original_data: a single Task or a list of Tasks (first is used).

        Raises:
            ValueError: if no task was provided.
        """
        if isinstance(original_data, list):
            task = original_data[0] if original_data else None
        else:
            task = original_data

        if not task:
            raise ValueError("No task provided for saving")

        # Calculate word count on tag-stripped text (content may contain HTML)
        text_for_counting = re.sub(r'<[^>]+>', '', parsed)
        word_count = len(text_for_counting.split())

        # Update task
        task.content = parsed
        task.word_count = word_count
        task.meta_title = task.title
        task.meta_description = (task.description or '')[:160]  # SERP snippet cap
        task.status = 'draft'
        task.save()

        return {
            'count': 1,
            'tasks_updated': 1,
            'word_count': word_count
        }
|
||||
|
||||
|
||||
def generate_content_core(task_ids: List[int], account_id: int = None, progress_callback=None):
    """
    Core logic for generating content (legacy function signature for backward compatibility).
    Can be called with or without Celery.

    Args:
        task_ids: List of task IDs
        account_id: Account ID for account isolation
        progress_callback: Optional function to call for progress updates
            (accepted for backward compatibility; not invoked in this body)

    Returns:
        Dict with 'success', 'tasks_updated', 'message', etc.
    """
    try:
        from igny8_core.auth.models import Account

        account = None
        if account_id:
            account = Account.objects.get(id=account_id)

        # Use the new function class
        fn = GenerateContentFunction()
        fn.account = account

        # Prepare payload
        payload = {'ids': task_ids}

        # Validate
        validated = fn.validate(payload, account)
        if not validated['valid']:
            return {'success': False, 'error': validated['error']}

        # Prepare data
        tasks = fn.prepare(payload, account)

        # One AI client for the whole batch: it depends only on the
        # account, which is loop-invariant, so build it once (hoisted
        # out of the per-task loop).
        ai_core = AICore(account=account)

        tasks_updated = 0

        # Process each task independently; a failure on one task is
        # logged and skipped so the rest of the batch still runs.
        for task in tasks:
            # Build prompt for this task
            prompt = fn.build_prompt([task], account)

            # Call AI using centralized request handler
            result = ai_core.run_ai_request(
                prompt=prompt,
                max_tokens=4000,
                function_name='generate_content'
            )

            if result.get('error'):
                logger.error(f"AI error for task {task.id}: {result['error']}")
                continue

            # Parse response
            content = fn.parse_response(result['content'])

            if not content:
                logger.warning(f"No content generated for task {task.id}")
                continue

            # Save output
            save_result = fn.save_output(content, [task], account)
            tasks_updated += save_result.get('tasks_updated', 0)

        return {
            'success': True,
            'tasks_updated': tasks_updated,
            'message': f'Content generation complete: {tasks_updated} articles generated'
        }

    except Exception as e:
        logger.error(f"Error in generate_content_core: {str(e)}", exc_info=True)
        return {'success': False, 'error': str(e)}
|
||||
|
||||
256
backend/igny8_core/ai/functions/generate_ideas.py
Normal file
256
backend/igny8_core/ai/functions/generate_ideas.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
Generate Ideas AI Function
|
||||
Extracted from modules/planner/tasks.py
|
||||
"""
|
||||
import logging
|
||||
import json
|
||||
from typing import Dict, List, Any
|
||||
from django.db import transaction
|
||||
from igny8_core.ai.base import BaseAIFunction
|
||||
from igny8_core.modules.planner.models import Clusters, ContentIdeas
|
||||
from igny8_core.modules.system.utils import get_prompt_value
|
||||
from igny8_core.ai.ai_core import AICore
|
||||
from igny8_core.ai.validators import validate_cluster_exists, validate_cluster_limits
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenerateIdeasFunction(BaseAIFunction):
    """Generate SEO content ideas from keyword clusters using AI.

    Pipeline (driven by the base class / ``generate_ideas_core``):
    validate -> prepare -> build_prompt -> parse_response -> save_output.
    """

    def get_name(self) -> str:
        """Return the function name used for AI request tracking/registry."""
        return 'generate_ideas'

    def get_metadata(self) -> Dict:
        """Return display metadata and per-phase progress messages."""
        return {
            'display_name': 'Generate Ideas',
            'description': 'Generate SEO-optimized content ideas from keyword clusters',
            'phases': {
                'INIT': 'Initializing idea generation...',
                'PREP': 'Loading clusters...',
                'AI_CALL': 'Generating ideas with AI...',
                'PARSE': 'Parsing idea data...',
                'SAVE': 'Saving ideas...',
                'DONE': 'Ideas generated!'
            }
        }

    def get_max_items(self) -> int:
        """Upper bound on the number of cluster IDs accepted per run."""
        return 10  # Max clusters per idea generation

    def validate(self, payload: dict, account=None) -> Dict:
        """Validate cluster IDs and plan limits.

        Runs base validation, then checks the (first) cluster exists and
        that the account's plan allows another 'idea' operation.

        Returns:
            Dict with 'valid' (bool) and, on failure, an 'error' message.
        """
        result = super().validate(payload, account)
        if not result['valid']:
            return result

        # Check cluster exists
        cluster_ids = payload.get('ids', [])
        if cluster_ids:
            cluster_id = cluster_ids[0]  # For single cluster idea generation
            cluster_result = validate_cluster_exists(cluster_id, account)
            if not cluster_result['valid']:
                return cluster_result

        # Check plan limits
        limit_result = validate_cluster_limits(account, operation_type='idea')
        if not limit_result['valid']:
            return limit_result

        return {'valid': True}

    def prepare(self, payload: dict, account=None) -> Dict:
        """Load the cluster with its keywords and format it for the prompt.

        Raises:
            ValueError: if no cluster ID was given or the cluster is not found.
        """
        cluster_ids = payload.get('ids', [])
        if not cluster_ids:
            raise ValueError("No cluster IDs provided")

        cluster_id = cluster_ids[0]  # Single cluster for now

        queryset = Clusters.objects.filter(id=cluster_id)
        if account:
            queryset = queryset.filter(account=account)

        cluster = queryset.select_related('sector', 'account', 'site').prefetch_related('keywords').first()

        if not cluster:
            raise ValueError("Cluster not found")

        # Get keywords for this cluster
        from igny8_core.modules.planner.models import Keywords
        keywords = Keywords.objects.filter(cluster=cluster).values_list('keyword', flat=True)

        # Format cluster data for AI
        cluster_data = [{
            'id': cluster.id,
            'name': cluster.name,
            'description': cluster.description or '',
            'keywords': list(keywords),
        }]

        return {
            'cluster': cluster,
            'cluster_data': cluster_data,
            'account': account or cluster.account
        }

    def build_prompt(self, data: Dict, account=None) -> str:
        """Build the ideas-generation prompt from prepared cluster data."""
        cluster_data = data['cluster_data']

        # Get prompt template
        prompt_template = get_prompt_value(account or data.get('account'), 'ideas')

        # Format clusters text
        clusters_text = '\n'.join([
            f"Cluster ID: {c.get('id', '')} | Name: {c.get('name', '')} | Description: {c.get('description', '')}"
            for c in cluster_data
        ])

        # Format cluster keywords
        cluster_keywords_text = '\n'.join([
            f"Cluster ID: {c.get('id', '')} | Name: {c.get('name', '')} | Keywords: {', '.join(c.get('keywords', []))}"
            for c in cluster_data
        ])

        # Replace placeholders
        prompt = prompt_template.replace('[IGNY8_CLUSTERS]', clusters_text)
        prompt = prompt.replace('[IGNY8_CLUSTER_KEYWORDS]', cluster_keywords_text)

        return prompt

    def parse_response(self, response: str, step_tracker=None) -> List[Dict]:
        """Parse the AI response into a list of idea dicts.

        Raises:
            ValueError: if no JSON could be extracted or it has no 'ideas' key.
        """
        # getattr default instead of hasattr/ternary — same behavior,
        # consistent with the other AI functions in this package.
        ai_core = AICore(account=getattr(self, 'account', None))
        json_data = ai_core.extract_json(response)

        if not json_data or 'ideas' not in json_data:
            error_msg = f"Failed to parse ideas response: {response[:200]}..."
            logger.error(error_msg)
            raise ValueError(error_msg)

        return json_data.get('ideas', [])

    def save_output(
        self,
        parsed: List[Dict],
        original_data: Dict,
        account=None,
        progress_tracker=None,
        step_tracker=None
    ) -> Dict:
        """Save parsed ideas as ContentIdeas rows (atomically).

        Args:
            parsed: list of idea dicts from parse_response.
            original_data: the dict returned by prepare (needs 'cluster').

        Raises:
            ValueError: if no account can be resolved.
        """
        cluster = original_data['cluster']
        account = account or original_data.get('account')

        if not account:
            raise ValueError("Account is required for idea creation")

        ideas_created = 0

        with transaction.atomic():
            for idea_data in parsed:
                # Handle description - might be dict or string
                description = idea_data.get('description', '')
                if isinstance(description, dict):
                    description = json.dumps(description)
                elif not isinstance(description, str):
                    description = str(description)

                # Handle target_keywords (legacy key 'covered_keywords' first)
                target_keywords = idea_data.get('covered_keywords', '') or idea_data.get('target_keywords', '')

                # Create ContentIdeas record
                ContentIdeas.objects.create(
                    idea_title=idea_data.get('title', 'Untitled Idea'),
                    description=description,
                    content_type=idea_data.get('content_type', 'blog_post'),
                    content_structure=idea_data.get('content_structure', 'supporting_page'),
                    target_keywords=target_keywords,
                    keyword_cluster=cluster,
                    estimated_word_count=idea_data.get('estimated_word_count', 1500),
                    status='new',
                    account=account,
                    site=cluster.site,
                    sector=cluster.sector,
                )
                ideas_created += 1

        return {
            'count': ideas_created,
            'ideas_created': ideas_created
        }
|
||||
|
||||
|
||||
def generate_ideas_core(cluster_id: int, account_id: int = None, progress_callback=None):
    """
    Core logic for generating ideas (legacy function signature for backward compatibility).
    Can be called with or without Celery.

    Args:
        cluster_id: Cluster ID to generate idea for
        account_id: Account ID for account isolation
        progress_callback: Optional function to call for progress updates

    Returns:
        Dict with 'success', 'idea_created', 'message', etc.
    """
    try:
        from igny8_core.auth.models import Account

        # Resolve the account when an ID was supplied.
        account = Account.objects.get(id=account_id) if account_id else None

        # Drive the class-based implementation with the legacy arguments.
        ideas_fn = GenerateIdeasFunction()
        ideas_fn.account = account  # stored for use inside parse_response

        request_payload = {'ids': [cluster_id]}

        # Validation failures are surfaced to the caller, not raised.
        validation = ideas_fn.validate(request_payload, account)
        if not validation['valid']:
            return {'success': False, 'error': validation['error']}

        # Load the cluster and build the prompt from it.
        prepared = ideas_fn.prepare(request_payload, account)
        prompt_text = ideas_fn.build_prompt(prepared, account)

        # Centralized AI request handling.
        ai_core = AICore(account=account)
        ai_result = ai_core.run_ai_request(
            prompt=prompt_text,
            max_tokens=4000,
            function_name='generate_ideas'
        )

        if ai_result.get('error'):
            return {'success': False, 'error': ai_result['error']}

        ideas_data = ideas_fn.parse_response(ai_result['content'])

        if not ideas_data:
            return {'success': False, 'error': 'No ideas generated by AI'}

        # First idea's title feeds the user-facing message below.
        idea_data = ideas_data[0]

        # Persist every parsed idea.
        save_result = ideas_fn.save_output(ideas_data, prepared, account)

        return {
            'success': True,
            'idea_created': save_result['ideas_created'],
            'message': f"Idea '{idea_data.get('title', 'Untitled')}' created"
        }

    except Exception as e:
        logger.error(f"Error in generate_ideas_core: {str(e)}", exc_info=True)
        return {'success': False, 'error': str(e)}
|
||||
|
||||
275
backend/igny8_core/ai/functions/generate_images.py
Normal file
275
backend/igny8_core/ai/functions/generate_images.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Generate Images AI Function
|
||||
Extracted from modules/writer/tasks.py
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, List, Any
|
||||
from django.db import transaction
|
||||
from igny8_core.ai.base import BaseAIFunction
|
||||
from igny8_core.modules.writer.models import Tasks, Images
|
||||
from igny8_core.modules.system.utils import get_prompt_value, get_default_prompt
|
||||
from igny8_core.ai.ai_core import AICore
|
||||
from igny8_core.ai.validators import validate_tasks_exist
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenerateImagesFunction(BaseAIFunction):
    """Generate images for tasks using AI.

    NOTE(review): unlike the sibling functions, build_prompt() here performs
    an AI call itself (to extract image prompts from the task content); the
    actual image rendering is orchestrated by generate_images_core.
    """

    def get_name(self) -> str:
        """Return the function name used for AI request tracking/registry."""
        return 'generate_images'

    def get_metadata(self) -> Dict:
        """Return display metadata and per-phase progress messages."""
        return {
            'display_name': 'Generate Images',
            'description': 'Generate featured and in-article images for tasks',
            'phases': {
                'INIT': 'Initializing image generation...',
                'PREP': 'Extracting image prompts...',
                'AI_CALL': 'Generating images with AI...',
                'PARSE': 'Processing image URLs...',
                'SAVE': 'Saving images...',
                'DONE': 'Images generated!'
            }
        }

    def get_max_items(self) -> int:
        """Upper bound on the number of task IDs accepted per batch."""
        return 20  # Max tasks per batch

    def validate(self, payload: dict, account=None) -> Dict:
        """Validate task IDs: base validation, then existence check."""
        result = super().validate(payload, account)
        if not result['valid']:
            return result

        # Check tasks exist (account-scoped when account is provided)
        task_ids = payload.get('ids', [])
        if task_ids:
            task_result = validate_tasks_exist(task_ids, account)
            if not task_result['valid']:
                return task_result

        return {'valid': True}

    def prepare(self, payload: dict, account=None) -> Dict:
        """Load tasks and the account's image generation settings.

        Returns a dict bundling the tasks with resolved provider/model and
        image options (defaults applied when no integration is configured).

        Raises:
            ValueError: if no task matches the given IDs.
        """
        task_ids = payload.get('ids', [])

        queryset = Tasks.objects.filter(id__in=task_ids)
        if account:
            queryset = queryset.filter(account=account)

        tasks = list(queryset.select_related('account', 'sector', 'site'))

        if not tasks:
            raise ValueError("No tasks found")

        # Get image generation settings; best-effort — a missing or
        # inactive integration silently falls back to the defaults below.
        image_settings = {}
        if account:
            try:
                from igny8_core.modules.system.models import IntegrationSettings
                integration = IntegrationSettings.objects.get(
                    account=account,
                    integration_type='image_generation',
                    is_active=True
                )
                image_settings = integration.config or {}
            except Exception:
                pass

        # Extract settings with defaults ('service'/'runwareModel' are
        # presumably legacy config keys kept for compatibility — verify)
        provider = image_settings.get('provider') or image_settings.get('service', 'openai')
        if provider == 'runware':
            model = image_settings.get('model') or image_settings.get('runwareModel', 'runware:97@1')
        else:
            model = image_settings.get('model', 'dall-e-3')

        return {
            'tasks': tasks,
            'account': account,
            'provider': provider,
            'model': model,
            'image_type': image_settings.get('image_type', 'realistic'),
            'max_in_article_images': int(image_settings.get('max_in_article_images', 2)),
            'desktop_enabled': image_settings.get('desktop_enabled', True),
            'mobile_enabled': image_settings.get('mobile_enabled', True),
        }

    def build_prompt(self, data: Dict, account=None) -> Dict:
        """Extract image prompts from task content via an AI call.

        Args:
            data: prepare() output plus a 'task' key for the current task.

        Returns:
            Dict with 'featured_prompt' (str) and 'in_article_prompts' (list).

        Raises:
            ValueError: on missing content, AI error, or unparseable response.
        """
        task = data.get('task')
        max_images = data.get('max_in_article_images', 2)

        if not task or not task.content:
            raise ValueError("Task has no content")

        # Use AI to extract image prompts
        ai_core = AICore(account=account or data.get('account'))

        # Get prompt template (account override, then system default)
        prompt_template = get_prompt_value(account or data.get('account'), 'image_prompt_extraction')
        if not prompt_template:
            prompt_template = get_default_prompt('image_prompt_extraction')

        # Format prompt
        prompt = prompt_template.format(
            title=task.title,
            content=task.content[:5000],  # Limit content length
            max_images=max_images
        )

        # Call AI to extract prompts using centralized request handler
        result = ai_core.run_ai_request(
            prompt=prompt,
            max_tokens=1000,
            function_name='extract_image_prompts'
        )

        if result.get('error'):
            raise ValueError(f"Failed to extract image prompts: {result['error']}")

        # Parse JSON response
        json_data = ai_core.extract_json(result['content'])

        if not json_data:
            raise ValueError("Failed to parse image prompts response")

        return {
            'featured_prompt': json_data.get('featured_prompt', ''),
            'in_article_prompts': json_data.get('in_article_prompts', [])
        }

    def parse_response(self, response: Dict, step_tracker=None) -> Dict:
        """Parse image generation response (already parsed, just return)"""
        return response

    def save_output(
        self,
        parsed: Dict,
        original_data: Dict,
        account=None,
        progress_tracker=None,
        step_tracker=None
    ) -> Dict:
        """Save one generated image to the database.

        Args:
            parsed: dict with 'url' and 'image_type' ('featured'/'desktop'/'mobile').
            original_data: dict carrying the 'task' the image belongs to.

        Raises:
            ValueError: if the task or the image URL is missing.
        """
        task = original_data.get('task')
        image_url = parsed.get('url')
        image_type = parsed.get('image_type')  # 'featured', 'desktop', 'mobile'

        if not task or not image_url:
            raise ValueError("Missing task or image URL")

        # Create Images record
        image = Images.objects.create(
            task=task,
            image_url=image_url,
            image_type=image_type,
            account=account or task.account,
            site=task.site,
            sector=task.sector,
        )

        return {
            'count': 1,
            'images_created': 1,
            'image_id': image.id
        }
|
||||
|
||||
|
||||
def generate_images_core(task_ids: List[int], account_id: int = None, progress_callback=None):
    """
    Core logic for generating images (legacy function signature for backward compatibility).
    Can be called with or without Celery.

    Args:
        task_ids: List of task IDs
        account_id: Account ID for account isolation
        progress_callback: Optional function to call for progress updates
            (accepted for backward compatibility; not invoked in this body)

    Returns:
        Dict with 'success', 'images_created', 'message', etc.
    """
    try:
        from igny8_core.auth.models import Account

        account = None
        if account_id:
            account = Account.objects.get(id=account_id)

        # Use the new function class
        fn = GenerateImagesFunction()
        fn.account = account

        # Prepare payload
        payload = {'ids': task_ids}

        # Validate
        validated = fn.validate(payload, account)
        if not validated['valid']:
            return {'success': False, 'error': validated['error']}

        # Prepare data (tasks + resolved provider/model/settings)
        data = fn.prepare(payload, account)
        tasks = data['tasks']

        # Get prompts (account override, then system default)
        image_prompt_template = get_prompt_value(account, 'image_prompt_template')
        if not image_prompt_template:
            image_prompt_template = get_default_prompt('image_prompt_template')

        negative_prompt = get_prompt_value(account, 'negative_prompt')
        if not negative_prompt:
            negative_prompt = get_default_prompt('negative_prompt')

        ai_core = AICore(account=account)
        images_created = 0

        # Process each task; tasks without content are skipped silently
        for task in tasks:
            if not task.content:
                continue

            # Extract image prompts (build_prompt performs an AI call)
            prompts_data = fn.build_prompt({'task': task, **data}, account)
            featured_prompt = prompts_data['featured_prompt']
            # NOTE(review): in_article_prompts is extracted but not yet
            # consumed — in-article generation is stubbed out below.
            in_article_prompts = prompts_data['in_article_prompts']

            # Format featured prompt
            formatted_featured = image_prompt_template.format(
                image_type=data['image_type'],
                post_title=task.title,
                image_prompt=featured_prompt
            )

            # Generate featured image using centralized handler
            featured_result = ai_core.generate_image(
                prompt=formatted_featured,
                provider=data['provider'],
                model=data['model'],
                negative_prompt=negative_prompt,
                function_name='generate_images'
            )

            # Only persist on success; errors are ignored for this task
            if not featured_result.get('error') and featured_result.get('url'):
                fn.save_output(
                    {'url': featured_result['url'], 'image_type': 'featured'},
                    {'task': task, **data},
                    account
                )
                images_created += 1

            # Generate in-article images (desktop/mobile if enabled)
            # ... (simplified for now, full logic in tasks.py)

        return {
            'success': True,
            'images_created': images_created,
            'message': f'Image generation complete: {images_created} images created'
        }

    except Exception as e:
        logger.error(f"Error in generate_images_core: {str(e)}", exc_info=True)
        return {'success': False, 'error': str(e)}
|
||||
|
||||
Reference in New Issue
Block a user