Add Generate Image Prompts Functionality: Implement new AI function for generating image prompts, update API endpoints, and integrate with frontend actions for content management.
This commit is contained in:
@@ -5,6 +5,7 @@ from igny8_core.ai.functions.auto_cluster import AutoClusterFunction
|
||||
from igny8_core.ai.functions.generate_ideas import GenerateIdeasFunction
|
||||
from igny8_core.ai.functions.generate_content import GenerateContentFunction
|
||||
from igny8_core.ai.functions.generate_images import GenerateImagesFunction, generate_images_core
|
||||
from igny8_core.ai.functions.generate_image_prompts import GenerateImagePromptsFunction
|
||||
|
||||
__all__ = [
|
||||
'AutoClusterFunction',
|
||||
@@ -12,4 +13,5 @@ __all__ = [
|
||||
'GenerateContentFunction',
|
||||
'GenerateImagesFunction',
|
||||
'generate_images_core',
|
||||
'GenerateImagePromptsFunction',
|
||||
]
|
||||
|
||||
249
backend/igny8_core/ai/functions/generate_image_prompts.py
Normal file
249
backend/igny8_core/ai/functions/generate_image_prompts.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""
|
||||
Generate Image Prompts AI Function
|
||||
Extracts image prompts from content using AI
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, List, Any
|
||||
from django.db import transaction
|
||||
from igny8_core.ai.base import BaseAIFunction
|
||||
from igny8_core.modules.writer.models import Content, Images
|
||||
from igny8_core.ai.ai_core import AICore
|
||||
from igny8_core.ai.validators import validate_ids
|
||||
from igny8_core.ai.prompts import PromptRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenerateImagePromptsFunction(BaseAIFunction):
|
||||
"""Generate image prompts from content using AI"""
|
||||
|
||||
def get_name(self) -> str:
|
||||
return 'generate_image_prompts'
|
||||
|
||||
def get_metadata(self) -> Dict:
|
||||
return {
|
||||
'display_name': 'Generate Image Prompts',
|
||||
'description': 'Extract image prompts from content (title, intro, H2 headings)',
|
||||
'phases': {
|
||||
'INIT': 'Initializing prompt generation...',
|
||||
'PREP': 'Loading content and extracting elements...',
|
||||
'AI_CALL': 'Generating prompts with AI...',
|
||||
'PARSE': 'Parsing prompt data...',
|
||||
'SAVE': 'Saving prompts...',
|
||||
'DONE': 'Prompts generated!'
|
||||
}
|
||||
}
|
||||
|
||||
def get_max_items(self) -> int:
|
||||
return 50 # Max content records per batch
|
||||
|
||||
def validate(self, payload: dict, account=None) -> Dict:
|
||||
"""Validate content IDs exist"""
|
||||
result = validate_ids(payload, max_items=self.get_max_items())
|
||||
if not result['valid']:
|
||||
return result
|
||||
|
||||
# Check content records exist
|
||||
content_ids = payload.get('ids', [])
|
||||
if content_ids:
|
||||
queryset = Content.objects.filter(id__in=content_ids)
|
||||
if account:
|
||||
queryset = queryset.filter(account=account)
|
||||
|
||||
if queryset.count() == 0:
|
||||
return {'valid': False, 'error': 'No content records found'}
|
||||
|
||||
return {'valid': True}
|
||||
|
||||
def prepare(self, payload: dict, account=None) -> List:
|
||||
"""Load content records and extract elements for prompt generation"""
|
||||
content_ids = payload.get('ids', [])
|
||||
|
||||
queryset = Content.objects.filter(id__in=content_ids)
|
||||
if account:
|
||||
queryset = queryset.filter(account=account)
|
||||
|
||||
contents = list(queryset.select_related('task', 'account', 'site', 'sector'))
|
||||
|
||||
if not contents:
|
||||
raise ValueError("No content records found")
|
||||
|
||||
# Get max_in_article_images from IntegrationSettings
|
||||
max_images = self._get_max_in_article_images(account)
|
||||
|
||||
# Extract content elements for each content record
|
||||
extracted_data = []
|
||||
for content in contents:
|
||||
extracted = self._extract_content_elements(content, max_images)
|
||||
extracted_data.append({
|
||||
'content': content,
|
||||
'extracted': extracted,
|
||||
'max_images': max_images,
|
||||
})
|
||||
|
||||
return extracted_data
|
||||
|
||||
def build_prompt(self, data: Any, account=None) -> str:
|
||||
"""Build prompt using PromptRegistry - handles list of content items"""
|
||||
# Handle list of content items (from prepare)
|
||||
if isinstance(data, list):
|
||||
if not data:
|
||||
raise ValueError("No content items provided")
|
||||
# For now, process first item (can be extended to batch process all)
|
||||
data = data[0]
|
||||
|
||||
extracted = data['extracted']
|
||||
max_images = data.get('max_images', 2)
|
||||
|
||||
# Format content for prompt
|
||||
content_text = self._format_content_for_prompt(extracted)
|
||||
|
||||
# Get prompt from PromptRegistry - same as other functions
|
||||
prompt = PromptRegistry.get_prompt(
|
||||
function_name='generate_image_prompts',
|
||||
account=account,
|
||||
context={
|
||||
'title': extracted['title'],
|
||||
'content': content_text,
|
||||
'max_images': max_images,
|
||||
}
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
def parse_response(self, response: str, step_tracker=None) -> Dict:
|
||||
"""Parse AI response - same pattern as other functions"""
|
||||
ai_core = AICore(account=getattr(self, 'account', None))
|
||||
json_data = ai_core.extract_json(response)
|
||||
|
||||
if not json_data:
|
||||
raise ValueError(f"Failed to parse image prompts response: {response[:200]}...")
|
||||
|
||||
# Validate structure
|
||||
if 'featured_prompt' not in json_data:
|
||||
raise ValueError("Missing 'featured_prompt' in AI response")
|
||||
|
||||
if 'in_article_prompts' not in json_data:
|
||||
raise ValueError("Missing 'in_article_prompts' in AI response")
|
||||
|
||||
return json_data
|
||||
|
||||
def save_output(
|
||||
self,
|
||||
parsed: Dict,
|
||||
original_data: Any,
|
||||
account=None,
|
||||
progress_tracker=None,
|
||||
step_tracker=None
|
||||
) -> Dict:
|
||||
"""Save prompts to Images model - handles list of content items"""
|
||||
# Handle list of content items (from prepare)
|
||||
if isinstance(original_data, list):
|
||||
if not original_data:
|
||||
raise ValueError("No content items provided")
|
||||
# For now, process first item (can be extended to batch process all)
|
||||
original_data = original_data[0]
|
||||
|
||||
content = original_data['content']
|
||||
extracted = original_data['extracted']
|
||||
max_images = original_data.get('max_images', 2)
|
||||
|
||||
prompts_created = 0
|
||||
|
||||
with transaction.atomic():
|
||||
# Save featured image prompt
|
||||
Images.objects.update_or_create(
|
||||
task=content.task,
|
||||
image_type='featured',
|
||||
defaults={
|
||||
'prompt': parsed['featured_prompt'],
|
||||
'status': 'pending',
|
||||
'position': 0,
|
||||
}
|
||||
)
|
||||
prompts_created += 1
|
||||
|
||||
# Save in-article image prompts
|
||||
in_article_prompts = parsed.get('in_article_prompts', [])
|
||||
h2_headings = extracted.get('h2_headings', [])
|
||||
|
||||
for idx, prompt_text in enumerate(in_article_prompts[:max_images]):
|
||||
heading = h2_headings[idx] if idx < len(h2_headings) else f"Section {idx + 1}"
|
||||
|
||||
Images.objects.update_or_create(
|
||||
task=content.task,
|
||||
image_type='in_article',
|
||||
position=idx + 1,
|
||||
defaults={
|
||||
'prompt': prompt_text,
|
||||
'status': 'pending',
|
||||
}
|
||||
)
|
||||
prompts_created += 1
|
||||
|
||||
return {
|
||||
'count': prompts_created,
|
||||
'prompts_created': prompts_created,
|
||||
}
|
||||
|
||||
# Helper methods
|
||||
def _get_max_in_article_images(self, account) -> int:
|
||||
"""Get max_in_article_images from IntegrationSettings"""
|
||||
try:
|
||||
from igny8_core.modules.system.models import IntegrationSettings
|
||||
settings = IntegrationSettings.objects.get(
|
||||
account=account,
|
||||
integration_type='image_generation'
|
||||
)
|
||||
return settings.config.get('max_in_article_images', 2)
|
||||
except IntegrationSettings.DoesNotExist:
|
||||
return 2 # Default
|
||||
|
||||
def _extract_content_elements(self, content: Content, max_images: int) -> Dict:
|
||||
"""Extract title, intro paragraphs, and H2 headings from content HTML"""
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
html_content = content.html_content or ''
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
|
||||
# Extract title
|
||||
title = content.title or content.task.title or ''
|
||||
|
||||
# Extract first 1-2 intro paragraphs (skip italic hook if present)
|
||||
paragraphs = soup.find_all('p')
|
||||
intro_paragraphs = []
|
||||
for p in paragraphs[:3]: # Check first 3 paragraphs
|
||||
text = p.get_text(strip=True)
|
||||
# Skip italic hook (usually 30-40 words)
|
||||
if len(text.split()) > 50: # Real paragraph, not hook
|
||||
intro_paragraphs.append(text)
|
||||
if len(intro_paragraphs) >= 2:
|
||||
break
|
||||
|
||||
# Extract first N H2 headings
|
||||
h2_tags = soup.find_all('h2')
|
||||
h2_headings = [h2.get_text(strip=True) for h2 in h2_tags[:max_images]]
|
||||
|
||||
return {
|
||||
'title': title,
|
||||
'intro_paragraphs': intro_paragraphs,
|
||||
'h2_headings': h2_headings,
|
||||
}
|
||||
|
||||
def _format_content_for_prompt(self, extracted: Dict) -> str:
|
||||
"""Format extracted content for prompt input"""
|
||||
lines = []
|
||||
|
||||
if extracted.get('intro_paragraphs'):
|
||||
lines.append("ARTICLE INTRODUCTION:")
|
||||
for para in extracted['intro_paragraphs']:
|
||||
lines.append(para)
|
||||
lines.append("")
|
||||
|
||||
if extracted.get('h2_headings'):
|
||||
lines.append("ARTICLE HEADINGS (for in-article images):")
|
||||
for idx, heading in enumerate(extracted['h2_headings'], 1):
|
||||
lines.append(f"{idx}. {heading}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
@@ -274,6 +274,7 @@ Make sure each prompt is detailed enough for image generation, describing the vi
|
||||
'generate_content': 'content_generation',
|
||||
'generate_images': 'image_prompt_extraction',
|
||||
'extract_image_prompts': 'image_prompt_extraction',
|
||||
'generate_image_prompts': 'image_prompt_extraction',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -89,8 +89,14 @@ def _load_generate_images():
|
||||
from igny8_core.ai.functions.generate_images import GenerateImagesFunction
|
||||
return GenerateImagesFunction
|
||||
|
||||
def _load_generate_image_prompts():
|
||||
"""Lazy loader for generate_image_prompts function"""
|
||||
from igny8_core.ai.functions.generate_image_prompts import GenerateImagePromptsFunction
|
||||
return GenerateImagePromptsFunction
|
||||
|
||||
register_lazy_function('auto_cluster', _load_auto_cluster)
|
||||
register_lazy_function('generate_ideas', _load_generate_ideas)
|
||||
register_lazy_function('generate_content', _load_generate_content)
|
||||
register_lazy_function('generate_images', _load_generate_images)
|
||||
register_lazy_function('generate_image_prompts', _load_generate_image_prompts)
|
||||
|
||||
|
||||
@@ -34,6 +34,12 @@ MODEL_CONFIG = {
|
||||
"temperature": 0.7,
|
||||
"response_format": {"type": "json_object"},
|
||||
},
|
||||
"generate_image_prompts": {
|
||||
"model": "gpt-4o-mini",
|
||||
"max_tokens": 2000,
|
||||
"temperature": 0.7,
|
||||
"response_format": {"type": "json_object"},
|
||||
},
|
||||
}
|
||||
|
||||
# Function name aliases (for backward compatibility)
|
||||
|
||||
@@ -455,4 +455,57 @@ class ContentViewSet(SiteSectorModelViewSet):
|
||||
serializer.save(account=account)
|
||||
else:
|
||||
serializer.save()
|
||||
|
||||
@action(detail=False, methods=['post'], url_path='generate_image_prompts', url_name='generate_image_prompts')
|
||||
def generate_image_prompts(self, request):
|
||||
"""Generate image prompts for content records - same pattern as other AI functions"""
|
||||
from igny8_core.ai.tasks import run_ai_task
|
||||
|
||||
account = getattr(request, 'account', None)
|
||||
ids = request.data.get('ids', [])
|
||||
|
||||
if not ids:
|
||||
return Response({
|
||||
'error': 'No IDs provided',
|
||||
'type': 'ValidationError'
|
||||
}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
account_id = account.id if account else None
|
||||
|
||||
# Queue Celery task
|
||||
try:
|
||||
if hasattr(run_ai_task, 'delay'):
|
||||
task = run_ai_task.delay(
|
||||
function_name='generate_image_prompts',
|
||||
payload={'ids': ids},
|
||||
account_id=account_id
|
||||
)
|
||||
return Response({
|
||||
'success': True,
|
||||
'task_id': str(task.id),
|
||||
'message': 'Image prompt generation started'
|
||||
}, status=status.HTTP_200_OK)
|
||||
else:
|
||||
# Fallback to synchronous execution
|
||||
result = run_ai_task(
|
||||
function_name='generate_image_prompts',
|
||||
payload={'ids': ids},
|
||||
account_id=account_id
|
||||
)
|
||||
if result.get('success'):
|
||||
return Response({
|
||||
'success': True,
|
||||
'prompts_created': result.get('count', 0),
|
||||
'message': 'Image prompts generated successfully'
|
||||
}, status=status.HTTP_200_OK)
|
||||
else:
|
||||
return Response({
|
||||
'error': result.get('error', 'Image prompt generation failed'),
|
||||
'type': 'TaskExecutionError'
|
||||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
except Exception as e:
|
||||
return Response({
|
||||
'error': str(e),
|
||||
'type': 'ExecutionError'
|
||||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
@@ -258,8 +258,8 @@ const tableActionsConfigs: Record<string, TableActionsConfig> = {
|
||||
variant: 'primary',
|
||||
},
|
||||
{
|
||||
key: 'generate_images',
|
||||
label: 'Generate Images',
|
||||
key: 'generate_image_prompts',
|
||||
label: 'Generate Image Prompts',
|
||||
icon: <BoltIcon className="w-5 h-5 text-purple-500" />,
|
||||
variant: 'primary',
|
||||
},
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
fetchContent,
|
||||
Content as ContentType,
|
||||
ContentFilters,
|
||||
autoGenerateImages,
|
||||
generateImagePrompts,
|
||||
} from '../../services/api';
|
||||
import { useToast } from '../../components/ui/toast/ToastContainer';
|
||||
import { FileIcon } from '../../icons';
|
||||
@@ -147,27 +147,20 @@ export default function Content() {
|
||||
}, [pageConfig?.headerMetrics, content, totalCount]);
|
||||
|
||||
const handleRowAction = useCallback(async (action: string, row: ContentType) => {
|
||||
if (action === 'generate_images') {
|
||||
const taskId = row.task_id;
|
||||
if (!taskId) {
|
||||
toast.error('No task linked to this content for image generation');
|
||||
return;
|
||||
}
|
||||
|
||||
if (action === 'generate_image_prompts') {
|
||||
try {
|
||||
const result = await autoGenerateImages([taskId]);
|
||||
|
||||
const result = await generateImagePrompts([row.id]);
|
||||
if (result.success) {
|
||||
if (result.task_id) {
|
||||
toast.success('Image generation started');
|
||||
toast.success('Image prompts generation started');
|
||||
} else {
|
||||
toast.success(`Image generation complete: ${result.images_created || 0} image${(result.images_created || 0) === 1 ? '' : 's'} generated`);
|
||||
toast.success(`Image prompts generated: ${result.prompts_created || 0} prompt${(result.prompts_created || 0) === 1 ? '' : 's'} created`);
|
||||
}
|
||||
} else {
|
||||
toast.error(result.error || 'Failed to generate images');
|
||||
toast.error(result.error || 'Failed to generate image prompts');
|
||||
}
|
||||
} catch (error: any) {
|
||||
toast.error(`Failed to generate images: ${error.message}`);
|
||||
toast.error(`Failed to generate prompts: ${error.message}`);
|
||||
}
|
||||
}
|
||||
}, [toast]);
|
||||
|
||||
@@ -952,6 +952,13 @@ export async function autoGenerateImages(taskIds: number[]): Promise<{ success:
|
||||
}
|
||||
}
|
||||
|
||||
export async function generateImagePrompts(contentIds: number[]): Promise<any> {
|
||||
return fetchAPI('/v1/writer/content/generate_image_prompts/', {
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ ids: contentIds }),
|
||||
});
|
||||
}
|
||||
|
||||
// TaskImages API functions
|
||||
export interface TaskImage {
|
||||
id: number;
|
||||
|
||||
Reference in New Issue
Block a user