Compare commits
2 Commits
fix/securi
...
fix/pdf-du
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2afcdf7b77 | ||
|
|
13beb98164 |
218
apps/image_rag.py
Normal file
218
apps/image_rag.py
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
CLIP Image RAG Application
|
||||||
|
|
||||||
|
This application enables RAG (Retrieval-Augmented Generation) on images using CLIP embeddings.
|
||||||
|
You can index a directory of images and search them using text queries.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m apps.image_rag --image-dir ./my_images/ --query "a sunset over mountains"
|
||||||
|
python -m apps.image_rag --image-dir ./my_images/ --interactive
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import pickle
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from apps.base_rag_example import BaseRAGExample
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRAG(BaseRAGExample):
    """RAG application for images using CLIP embeddings.

    Wires a complete retrieval pipeline over image data: CLIP embedding
    generation, LEANN index construction, and text-based image search.
    """

    def __init__(self):
        super().__init__(
            name="Image RAG",
            description="RAG application for images using CLIP embeddings",
            default_index_name="image_index",
        )
        # Swap the base class's text-only embedding defaults for a CLIP
        # checkpoint served through sentence-transformers, so text queries
        # and images land in a shared embedding space.
        self.embedding_model_default = "clip-ViT-L-14"
        self.embedding_mode_default = "sentence-transformers"
        # Filled in by load_data(); one dict per indexed image
        # (text description, metadata, precomputed embedding).
        self._image_data: list[dict] = []
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||||
|
"""Add image-specific arguments."""
|
||||||
|
image_group = parser.add_argument_group("Image Parameters")
|
||||||
|
image_group.add_argument(
|
||||||
|
"--image-dir",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Directory containing images to index",
|
||||||
|
)
|
||||||
|
image_group.add_argument(
|
||||||
|
"--image-extensions",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
default=[".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"],
|
||||||
|
help="Image file extensions to process (default: .jpg .jpeg .png .gif .bmp .webp)",
|
||||||
|
)
|
||||||
|
image_group.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Batch size for CLIP embedding generation (default: 32)",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
    """Load images, generate CLIP embeddings, and return text descriptions."""
    # Keep the full records (metadata + embeddings) on the instance for
    # build_index(); hand back only the text descriptions to the caller.
    entries = self._load_images_and_embeddings(args)
    self._image_data = entries
    return [entry["text"] for entry in entries]
|
||||||
|
|
||||||
|
def _load_images_and_embeddings(self, args) -> list[dict]:
    """Discover images under ``args.image_dir``, embed them with CLIP, and
    return one record per image.

    Each record is a dict with keys:
      - ``"text"``: human-readable description used as the indexed document
      - ``"metadata"``: image path / name / directory (all strings)
      - ``"embedding"``: normalized float32 CLIP vector

    Raises:
        ValueError: if the directory does not exist, or no files match
            ``args.image_extensions``.
    """
    image_dir = Path(args.image_dir)
    if not image_dir.exists():
        raise ValueError(f"Image directory does not exist: {image_dir}")

    print(f"📸 Loading images from {image_dir}...")

    # Find all image files, matching both lower- and upper-case extensions.
    image_files = []
    for ext in args.image_extensions:
        image_files.extend(image_dir.rglob(f"*{ext}"))
        image_files.extend(image_dir.rglob(f"*{ext.upper()}"))
    # On case-insensitive filesystems (macOS, Windows) "*.jpg" and "*.JPG"
    # both match the same file, which would index every image twice.
    # Dedupe, and sort so the index order is deterministic across runs.
    image_files = sorted(set(image_files))

    if not image_files:
        raise ValueError(
            f"No images found in {image_dir} with extensions {args.image_extensions}"
        )

    print(f"✅ Found {len(image_files)} images")

    # Limit if max_items is set
    if args.max_items > 0:
        image_files = image_files[: args.max_items]
        print(f"📊 Processing {len(image_files)} images (limited by --max-items)")

    # Load CLIP model
    print("🔍 Loading CLIP model...")
    model = SentenceTransformer(self.embedding_model_default)

    # Process images and generate embeddings
    print("🖼️ Processing images and generating embeddings...")
    image_data: list[dict] = []
    batch_images: list = []
    batch_paths: list[Path] = []

    def flush_batch() -> None:
        """Encode the pending batch and append one record per image.

        Shared by the in-loop flush and the final partial-batch flush so the
        record-building logic exists in exactly one place.
        """
        if not batch_images:
            return
        embeddings = model.encode(
            batch_images,
            convert_to_numpy=True,
            normalize_embeddings=True,
            batch_size=len(batch_images),
            show_progress_bar=False,
        )
        for img_path, embedding in zip(batch_paths, embeddings):
            image_data.append(
                {
                    "text": f"Image: {img_path.name}\nPath: {img_path}",
                    "metadata": {
                        "image_path": str(img_path),
                        "image_name": img_path.name,
                        "image_dir": str(image_dir),
                    },
                    "embedding": embedding.astype(np.float32),
                }
            )
        batch_images.clear()
        batch_paths.clear()

    for image_path in tqdm(image_files, desc="Processing images"):
        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            # Skip unreadable/corrupt files instead of aborting the run.
            # The try covers only the file open, so an encoding failure is
            # no longer misattributed to the image currently being read.
            print(f"⚠️ Failed to process {image_path}: {e}")
            continue
        batch_images.append(image)
        batch_paths.append(image_path)
        if len(batch_images) >= args.batch_size:
            flush_batch()

    # Encode any remaining images that did not fill a complete batch.
    flush_batch()

    print(f"✅ Processed {len(image_data)} images")
    return image_data
|
||||||
|
|
||||||
|
async def build_index(self, args, texts: list[str]) -> str:
    """Build a LEANN index from the pre-computed CLIP embeddings.

    Args:
        args: parsed CLI namespace (backend, graph, and build parameters).
        texts: descriptions returned by load_data(); must align one-to-one
            with ``self._image_data``.

    Returns:
        Path to the built ``.leann`` index file.

    Raises:
        RuntimeError: if load_data() has not populated ``self._image_data``
            or its length disagrees with ``texts``.
    """
    from leann.api import LeannBuilder

    if not self._image_data or len(self._image_data) != len(texts):
        raise RuntimeError("No image data found. Make sure load_data() ran successfully.")

    print("🔨 Building LEANN index with CLIP embeddings...")
    builder = LeannBuilder(
        backend_name=args.backend_name,
        embedding_model=self.embedding_model_default,
        embedding_mode=self.embedding_mode_default,
        # Embeddings are pre-computed here, so the index never re-embeds
        # documents at query time.
        is_recompute=False,
        distance_metric="cosine",
        graph_degree=args.graph_degree,
        build_complexity=args.build_complexity,
        is_compact=not args.no_compact,
    )

    for text, data in zip(texts, self._image_data):
        builder.add_text(text=text, metadata=data["metadata"])

    ids = [str(i) for i in range(len(self._image_data))]
    embeddings = np.array([data["embedding"] for data in self._image_data], dtype=np.float32)

    # Hand the embeddings to the builder through a pickle file. The temp
    # file name is captured BEFORE writing, so the finally-clause removes
    # it even if pickle.dump itself raises (the previous layout leaked the
    # delete=False file in that case).
    tmp = tempfile.NamedTemporaryFile(mode="wb", suffix=".pkl", delete=False)
    pkl_path = tmp.name
    try:
        with tmp:
            pickle.dump((ids, embeddings), tmp)
        index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
        builder.build_index_from_embeddings(index_path, pkl_path)
        print(f"✅ Index built successfully at {index_path}")
        return index_path
    finally:
        Path(pkl_path).unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Main entry point for the image RAG application."""
    import asyncio

    # Drive the async BaseRAGExample pipeline to completion.
    asyncio.run(ImageRAG().run())
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point for `python -m apps.image_rag`.
if __name__ == "__main__":
    main()
|
||||||
@@ -7,7 +7,6 @@ for indexing in LEANN. It supports various Slack MCP server implementations and
|
|||||||
flexible message processing options.
|
flexible message processing options.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import ast
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -147,16 +146,16 @@ class SlackMCPReader:
|
|||||||
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
|
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
|
||||||
if match:
|
if match:
|
||||||
try:
|
try:
|
||||||
error_dict = ast.literal_eval(match.group(1))
|
error_dict = eval(match.group(1))
|
||||||
except (ValueError, SyntaxError):
|
except (ValueError, SyntaxError, NameError):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# Try alternative format
|
# Try alternative format
|
||||||
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
|
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
|
||||||
if match:
|
if match:
|
||||||
try:
|
try:
|
||||||
error_dict = ast.literal_eval(match.group(1))
|
error_dict = eval(match.group(1))
|
||||||
except (ValueError, SyntaxError):
|
except (ValueError, SyntaxError, NameError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if self._is_cache_sync_error(error_dict):
|
if self._is_cache_sync_error(error_dict):
|
||||||
|
|||||||
@@ -1162,6 +1162,11 @@ Examples:
|
|||||||
print(f"Warning: Could not process {file_path}: {e}")
|
print(f"Warning: Could not process {file_path}: {e}")
|
||||||
|
|
||||||
# Load other file types with default reader
|
# Load other file types with default reader
|
||||||
|
# Exclude PDFs from code_extensions if they were already processed separately
|
||||||
|
other_file_extensions = code_extensions
|
||||||
|
if should_process_pdfs and ".pdf" in code_extensions:
|
||||||
|
other_file_extensions = [ext for ext in code_extensions if ext != ".pdf"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Create a custom file filter function using our PathSpec
|
# Create a custom file filter function using our PathSpec
|
||||||
def file_filter(
|
def file_filter(
|
||||||
@@ -1177,15 +1182,19 @@ Examples:
|
|||||||
except (ValueError, OSError):
|
except (ValueError, OSError):
|
||||||
return True # Include files that can't be processed
|
return True # Include files that can't be processed
|
||||||
|
|
||||||
other_docs = SimpleDirectoryReader(
|
# Only load other file types if there are extensions to process
|
||||||
docs_dir,
|
if other_file_extensions:
|
||||||
recursive=True,
|
other_docs = SimpleDirectoryReader(
|
||||||
encoding="utf-8",
|
docs_dir,
|
||||||
required_exts=code_extensions,
|
recursive=True,
|
||||||
file_extractor={}, # Use default extractors
|
encoding="utf-8",
|
||||||
exclude_hidden=not include_hidden,
|
required_exts=other_file_extensions,
|
||||||
filename_as_id=True,
|
file_extractor={}, # Use default extractors
|
||||||
).load_data(show_progress=True)
|
exclude_hidden=not include_hidden,
|
||||||
|
filename_as_id=True,
|
||||||
|
).load_data(show_progress=True)
|
||||||
|
else:
|
||||||
|
other_docs = []
|
||||||
|
|
||||||
# Filter documents after loading based on gitignore rules
|
# Filter documents after loading based on gitignore rules
|
||||||
filtered_docs = []
|
filtered_docs = []
|
||||||
|
|||||||
Reference in New Issue
Block a user