Changelog:
- Fix slack_mcp_reader.py: channel parameter can be None
- Fix embedding_compute.py: ContextProp type issue
- Fix searcher_base.py: method override signatures
- Fix chunking_utils.py: chunk_text assignment
- Fix slack_rag.py and twitter_rag.py: return types
- Fix email.py and image_rag.py: method overrides

220 lines · 7.8 KiB · Python
#!/usr/bin/env python3
"""
CLIP Image RAG Application

This application enables RAG (Retrieval-Augmented Generation) on images using CLIP embeddings.
You can index a directory of images and search them using text queries.

Usage:
    python -m apps.image_rag --image-dir ./my_images/ --query "a sunset over mountains"
    python -m apps.image_rag --image-dir ./my_images/ --interactive
"""

import argparse
import pickle
import tempfile
from pathlib import Path
from typing import Any

import numpy as np
from PIL import Image
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from apps.base_rag_example import BaseRAGExample
|
class ImageRAG(BaseRAGExample):
    """
    RAG application for images using CLIP embeddings.

    This class provides a complete RAG pipeline for image data, including
    CLIP embedding generation, indexing, and text-based image search.
    """

    def __init__(self):
        super().__init__(
            name="Image RAG",
            description="RAG application for images using CLIP embeddings",
            default_index_name="image_index",
        )
        # Override default embedding model to use CLIP
        self.embedding_model_default = "clip-ViT-L-14"
        self.embedding_mode_default = "sentence-transformers"
        # Parallel store of per-image dicts ("text", "metadata", "embedding"),
        # populated by load_data() and consumed by build_index().
        self._image_data: list[dict] = []

    def _add_specific_arguments(self, parser: argparse.ArgumentParser) -> None:
        """Add image-specific CLI arguments to the base parser."""
        image_group = parser.add_argument_group("Image Parameters")
        image_group.add_argument(
            "--image-dir",
            type=str,
            required=True,
            help="Directory containing images to index",
        )
        image_group.add_argument(
            "--image-extensions",
            type=str,
            nargs="+",
            default=[".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"],
            help="Image file extensions to process (default: .jpg .jpeg .png .gif .bmp .webp)",
        )
        image_group.add_argument(
            "--batch-size",
            type=int,
            default=32,
            help="Batch size for CLIP embedding generation (default: 32)",
        )

    async def load_data(self, args) -> list[str]:
        """Load images, generate CLIP embeddings, and return text descriptions.

        Note: the previous annotation claimed ``list[dict[str, Any]]``, but the
        method has always returned the per-image text strings; the annotation
        now matches the actual (unchanged) return value.
        """
        self._image_data = self._load_images_and_embeddings(args)
        return [entry["text"] for entry in self._image_data]

    @staticmethod
    def _make_entry(img_path: Path, image_dir: Path, embedding: "np.ndarray") -> dict:
        """Build one index entry (text + metadata + float32 embedding) for an image."""
        return {
            "text": f"Image: {img_path.name}\nPath: {img_path}",
            "metadata": {
                "image_path": str(img_path),
                "image_name": img_path.name,
                "image_dir": str(image_dir),
            },
            "embedding": embedding.astype(np.float32),
        }

    def _flush_batch(self, model, batch_images: list, batch_paths: list,
                     image_dir: Path, image_data: list) -> None:
        """Encode a batch of PIL images with CLIP and append one entry per image.

        No-op on an empty batch, so it is safe to call unconditionally for the
        final partial batch.
        """
        if not batch_images:
            return
        embeddings = model.encode(
            batch_images,
            convert_to_numpy=True,
            normalize_embeddings=True,
            batch_size=len(batch_images),
            show_progress_bar=False,
        )
        for img_path, embedding in zip(batch_paths, embeddings):
            image_data.append(self._make_entry(img_path, image_dir, embedding))

    def _load_images_and_embeddings(self, args) -> list[dict]:
        """Discover images under ``args.image_dir`` and compute CLIP embeddings.

        Returns:
            One dict per successfully processed image with keys "text",
            "metadata" and "embedding" (float32, L2-normalized).

        Raises:
            ValueError: if the directory does not exist or contains no images
                matching ``args.image_extensions``.
        """
        image_dir = Path(args.image_dir)
        if not image_dir.exists():
            raise ValueError(f"Image directory does not exist: {image_dir}")

        print(f"📸 Loading images from {image_dir}...")

        # Find all image files (both lower- and upper-case extensions).
        image_files: list[Path] = []
        for ext in args.image_extensions:
            image_files.extend(image_dir.rglob(f"*{ext}"))
            image_files.extend(image_dir.rglob(f"*{ext.upper()}"))
        # On case-insensitive filesystems both rglob patterns can match the
        # same file; dedupe while preserving discovery order.
        image_files = list(dict.fromkeys(image_files))

        if not image_files:
            raise ValueError(
                f"No images found in {image_dir} with extensions {args.image_extensions}"
            )

        print(f"✅ Found {len(image_files)} images")

        # Limit if max_items is set (assumes the base parser defines
        # --max-items — TODO confirm against BaseRAGExample).
        if args.max_items > 0:
            image_files = image_files[: args.max_items]
            print(f"📊 Processing {len(image_files)} images (limited by --max-items)")

        # Load CLIP model
        print("🔍 Loading CLIP model...")
        model = SentenceTransformer(self.embedding_model_default)

        # Process images and generate embeddings
        print("🖼️ Processing images and generating embeddings...")
        image_data: list[dict] = []
        batch_images: list = []
        batch_paths: list[Path] = []

        for image_path in tqdm(image_files, desc="Processing images"):
            # Only the file open/decode is best-effort; an encode failure is a
            # real error and should propagate instead of being swallowed
            # (previously a failing encode left the batch intact and retried
            # the same failing batch on every subsequent image).
            try:
                image = Image.open(image_path).convert("RGB")
            except Exception as e:
                print(f"⚠️ Failed to process {image_path}: {e}")
                continue

            batch_images.append(image)
            batch_paths.append(image_path)

            if len(batch_images) >= args.batch_size:
                self._flush_batch(model, batch_images, batch_paths, image_dir, image_data)
                batch_images = []
                batch_paths = []

        # Encode any leftover partial batch.
        self._flush_batch(model, batch_images, batch_paths, image_dir, image_data)

        print(f"✅ Processed {len(image_data)} images")
        return image_data

    async def build_index(self, args, texts: list[str]) -> str:
        """Build a LEANN index using the pre-computed CLIP embeddings.

        ``texts`` must be the list returned by ``load_data()``; the matching
        embeddings are taken from ``self._image_data`` in the same order.
        (The previous ``list[dict[str, Any]]`` annotation was wrong — each
        element is passed directly as the ``text=`` string below.)

        Raises:
            RuntimeError: if ``load_data()`` has not populated the image data
                or its length disagrees with ``texts``.
        """
        from leann.api import LeannBuilder

        if not self._image_data or len(self._image_data) != len(texts):
            raise RuntimeError("No image data found. Make sure load_data() ran successfully.")

        print("🔨 Building LEANN index with CLIP embeddings...")
        builder = LeannBuilder(
            backend_name=args.backend_name,
            embedding_model=self.embedding_model_default,
            embedding_mode=self.embedding_mode_default,
            is_recompute=False,  # embeddings are supplied below, not recomputed
            distance_metric="cosine",
            graph_degree=args.graph_degree,
            build_complexity=args.build_complexity,
            is_compact=not args.no_compact,
        )

        for text, data in zip(texts, self._image_data):
            builder.add_text(text=text, metadata=data["metadata"])

        ids = [str(i) for i in range(len(self._image_data))]
        embeddings = np.array([data["embedding"] for data in self._image_data], dtype=np.float32)

        # Hand the embeddings to the builder through a temporary pickle file.
        with tempfile.NamedTemporaryFile(mode="wb", suffix=".pkl", delete=False) as f:
            pickle.dump((ids, embeddings), f)
            pkl_path = f.name

        try:
            index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
            builder.build_index_from_embeddings(index_path, pkl_path)
            print(f"✅ Index built successfully at {index_path}")
            return index_path
        finally:
            # Always remove the temp file; missing_ok avoids masking the real
            # exception if the file was never created or already cleaned up.
            Path(pkl_path).unlink(missing_ok=True)
|
|
|
|
|
|
def main():
    """Entry point: construct the app and drive its async run loop."""
    import asyncio

    asyncio.run(ImageRAG().run())
|
|
|
|
|
|
# Allow running as a script, e.g. `python -m apps.image_rag --image-dir ...`
if __name__ == "__main__":
    main()
|