Merge pull request #162 from yichuan-w/feature/colqwen-integration
add ColQwen multimodal PDF retrieval integration
This commit is contained in:
364
apps/colqwen_rag.py
Normal file
364
apps/colqwen_rag.py
Normal file
@@ -0,0 +1,364 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
ColQwen RAG - Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali
|
||||
|
||||
Usage:
|
||||
python -m apps.colqwen_rag build --pdfs ./my_pdfs/ --index my_index
|
||||
python -m apps.colqwen_rag search my_index "How does attention work?"
|
||||
python -m apps.colqwen_rag ask my_index --interactive
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
|
||||
# Add LEANN packages to path
|
||||
_repo_root = Path(__file__).resolve().parents[1]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
import torch # noqa: E402
|
||||
from colpali_engine import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor # noqa: E402
|
||||
from colpali_engine.utils.torch_utils import ListDataset # noqa: E402
|
||||
from pdf2image import convert_from_path # noqa: E402
|
||||
from PIL import Image # noqa: E402
|
||||
from torch.utils.data import DataLoader # noqa: E402
|
||||
from tqdm import tqdm # noqa: E402
|
||||
|
||||
# Import the existing multi-vector implementation
|
||||
sys.path.append(str(_repo_root / "apps" / "multimodal" / "vision-based-pdf-multi-vector"))
|
||||
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||
|
||||
|
||||
class ColQwenRAG:
|
||||
"""Easy-to-use ColQwen RAG system for multimodal PDF retrieval."""
|
||||
|
||||
def __init__(self, model_type: str = "colpali"):
|
||||
"""
|
||||
Initialize ColQwen RAG system.
|
||||
|
||||
Args:
|
||||
model_type: "colqwen2" or "colpali"
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.device = self._get_device()
|
||||
# Use float32 on MPS to avoid memory issues, float16 on CUDA, bfloat16 on CPU
|
||||
if self.device.type == "mps":
|
||||
self.dtype = torch.float32
|
||||
elif self.device.type == "cuda":
|
||||
self.dtype = torch.float16
|
||||
else:
|
||||
self.dtype = torch.bfloat16
|
||||
|
||||
print(f"🚀 Initializing {model_type.upper()} on {self.device} with {self.dtype}")
|
||||
|
||||
# Load model and processor with MPS-optimized settings
|
||||
try:
|
||||
if model_type == "colqwen2":
|
||||
self.model_name = "vidore/colqwen2-v1.0"
|
||||
if self.device.type == "mps":
|
||||
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.model = self.model.to(self.device)
|
||||
else:
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map=self.device,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.processor = ColQwen2Processor.from_pretrained(self.model_name)
|
||||
else: # colpali
|
||||
self.model_name = "vidore/colpali-v1.2"
|
||||
if self.device.type == "mps":
|
||||
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.model = self.model.to(self.device)
|
||||
else:
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map=self.device,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.processor = ColPaliProcessor.from_pretrained(self.model_name)
|
||||
except Exception as e:
|
||||
if "memory" in str(e).lower() or "offload" in str(e).lower():
|
||||
print(f"⚠️ Memory constraint on {self.device}, using CPU with optimizations...")
|
||||
self.device = torch.device("cpu")
|
||||
self.dtype = torch.float32
|
||||
|
||||
if model_type == "colqwen2":
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
else:
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
else:
|
||||
raise
|
||||
|
||||
def _get_device(self):
|
||||
"""Auto-select best available device."""
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def build_index(self, pdf_paths: list[str], index_name: str, pages_dir: Optional[str] = None):
|
||||
"""
|
||||
Build multimodal index from PDF files.
|
||||
|
||||
Args:
|
||||
pdf_paths: List of PDF file paths
|
||||
index_name: Name for the index
|
||||
pages_dir: Directory to save page images (optional)
|
||||
"""
|
||||
print(f"Building index '{index_name}' from {len(pdf_paths)} PDFs...")
|
||||
|
||||
# Convert PDFs to images
|
||||
all_images = []
|
||||
all_metadata = []
|
||||
|
||||
if pages_dir:
|
||||
os.makedirs(pages_dir, exist_ok=True)
|
||||
|
||||
for pdf_path in tqdm(pdf_paths, desc="Converting PDFs"):
|
||||
try:
|
||||
images = convert_from_path(pdf_path, dpi=150)
|
||||
pdf_name = Path(pdf_path).stem
|
||||
|
||||
for i, image in enumerate(images):
|
||||
# Save image if pages_dir specified
|
||||
if pages_dir:
|
||||
image_path = Path(pages_dir) / f"{pdf_name}_page_{i + 1}.png"
|
||||
image.save(image_path)
|
||||
|
||||
all_images.append(image)
|
||||
all_metadata.append(
|
||||
{
|
||||
"pdf_path": pdf_path,
|
||||
"pdf_name": pdf_name,
|
||||
"page_number": i + 1,
|
||||
"image_path": str(image_path) if pages_dir else None,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing {pdf_path}: {e}")
|
||||
continue
|
||||
|
||||
print(f"📄 Converted {len(all_images)} pages from {len(pdf_paths)} PDFs")
|
||||
print(f"All metadata: {all_metadata}")
|
||||
|
||||
# Generate embeddings
|
||||
print("🧠 Generating embeddings...")
|
||||
embeddings = self._embed_images(all_images)
|
||||
|
||||
# Build LEANN index
|
||||
print("🔍 Building LEANN index...")
|
||||
leann_mv = LeannMultiVector(
|
||||
index_path=index_name,
|
||||
dim=embeddings.shape[-1],
|
||||
embedding_model_name=self.model_type,
|
||||
)
|
||||
|
||||
# Create collection and insert data
|
||||
leann_mv.create_collection()
|
||||
for i, (embedding, metadata) in enumerate(zip(embeddings, all_metadata)):
|
||||
data = {
|
||||
"doc_id": i,
|
||||
"filepath": metadata.get("image_path", ""),
|
||||
"colbert_vecs": embedding.numpy(), # Convert tensor to numpy
|
||||
}
|
||||
leann_mv.insert(data)
|
||||
|
||||
# Build the index
|
||||
leann_mv.create_index()
|
||||
print(f"✅ Index '{index_name}' built successfully!")
|
||||
|
||||
return leann_mv
|
||||
|
||||
def search(self, index_name: str, query: str, top_k: int = 5):
|
||||
"""
|
||||
Search the index with a text query.
|
||||
|
||||
Args:
|
||||
index_name: Name of the index to search
|
||||
query: Text query
|
||||
top_k: Number of results to return
|
||||
"""
|
||||
print(f"🔍 Searching '{index_name}' for: '{query}'")
|
||||
|
||||
# Load index
|
||||
leann_mv = LeannMultiVector(
|
||||
index_path=index_name,
|
||||
dim=128, # Will be updated when loading
|
||||
embedding_model_name=self.model_type,
|
||||
)
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = self._embed_query(query)
|
||||
|
||||
# Search (returns list of (score, doc_id) tuples)
|
||||
search_results = leann_mv.search(query_embedding.numpy(), topk=top_k)
|
||||
|
||||
# Display results
|
||||
print(f"\n📋 Top {len(search_results)} results:")
|
||||
for i, (score, doc_id) in enumerate(search_results, 1):
|
||||
# Get metadata for this doc_id (we need to load the metadata)
|
||||
print(f"{i}. Score: {score:.3f} | Doc ID: {doc_id}")
|
||||
|
||||
return search_results
|
||||
|
||||
def ask(self, index_name: str, interactive: bool = False):
|
||||
"""
|
||||
Interactive Q&A with the indexed documents.
|
||||
|
||||
Args:
|
||||
index_name: Name of the index to query
|
||||
interactive: Whether to run in interactive mode
|
||||
"""
|
||||
print(f"💬 ColQwen Chat with '{index_name}'")
|
||||
|
||||
if interactive:
|
||||
print("Type 'quit' to exit, 'help' for commands")
|
||||
while True:
|
||||
try:
|
||||
query = input("\n🤔 Your question: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
break
|
||||
elif query.lower() == "help":
|
||||
print("Commands: quit/exit/q (exit), help (this message)")
|
||||
continue
|
||||
elif not query:
|
||||
continue
|
||||
|
||||
self.search(index_name, query, top_k=3)
|
||||
|
||||
# TODO: Add answer generation with Qwen-VL
|
||||
print("\n💡 For detailed answers, we can integrate Qwen-VL here!")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Goodbye!")
|
||||
break
|
||||
else:
|
||||
query = input("🤔 Your question: ").strip()
|
||||
if query:
|
||||
self.search(index_name, query)
|
||||
|
||||
def _embed_images(self, images: list[Image.Image]) -> torch.Tensor:
|
||||
"""Generate embeddings for a list of images."""
|
||||
dataset = ListDataset(images)
|
||||
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x)
|
||||
|
||||
embeddings = []
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc="Embedding images"):
|
||||
batch_images = cast(list, batch)
|
||||
batch_inputs = self.processor.process_images(batch_images).to(self.device)
|
||||
batch_embeddings = self.model(**batch_inputs)
|
||||
embeddings.append(batch_embeddings.cpu())
|
||||
|
||||
return torch.cat(embeddings, dim=0)
|
||||
|
||||
def _embed_query(self, query: str) -> torch.Tensor:
|
||||
"""Generate embedding for a text query."""
|
||||
with torch.no_grad():
|
||||
query_inputs = self.processor.process_queries([query]).to(self.device)
|
||||
query_embedding = self.model(**query_inputs)
|
||||
return query_embedding.cpu()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ColQwen RAG - Easy multimodal PDF retrieval")
|
||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser("build", help="Build index from PDFs")
|
||||
build_parser.add_argument("--pdfs", required=True, help="Directory containing PDF files")
|
||||
build_parser.add_argument("--index", required=True, help="Index name")
|
||||
build_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
build_parser.add_argument("--pages-dir", help="Directory to save page images")
|
||||
|
||||
# Search command
|
||||
search_parser = subparsers.add_parser("search", help="Search the index")
|
||||
search_parser.add_argument("index", help="Index name")
|
||||
search_parser.add_argument("query", help="Search query")
|
||||
search_parser.add_argument("--top-k", type=int, default=5, help="Number of results")
|
||||
search_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Interactive Q&A")
|
||||
ask_parser.add_argument("index", help="Index name")
|
||||
ask_parser.add_argument("--interactive", action="store_true", help="Interactive mode")
|
||||
ask_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
# Initialize ColQwen RAG
|
||||
if args.command == "build":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
|
||||
# Get PDF files
|
||||
pdf_dir = Path(args.pdfs)
|
||||
if pdf_dir.is_file() and pdf_dir.suffix.lower() == ".pdf":
|
||||
pdf_paths = [str(pdf_dir)]
|
||||
elif pdf_dir.is_dir():
|
||||
pdf_paths = [str(p) for p in pdf_dir.glob("*.pdf")]
|
||||
else:
|
||||
print(f"❌ Invalid PDF path: {args.pdfs}")
|
||||
return
|
||||
|
||||
if not pdf_paths:
|
||||
print(f"❌ No PDF files found in {args.pdfs}")
|
||||
return
|
||||
|
||||
colqwen.build_index(pdf_paths, args.index, args.pages_dir)
|
||||
|
||||
elif args.command == "search":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
colqwen.search(args.index, args.query, args.top_k)
|
||||
|
||||
elif args.command == "ask":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
colqwen.ask(args.index, args.interactive)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
218
apps/image_rag.py
Normal file
218
apps/image_rag.py
Normal file
@@ -0,0 +1,218 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CLIP Image RAG Application
|
||||
|
||||
This application enables RAG (Retrieval-Augmented Generation) on images using CLIP embeddings.
|
||||
You can index a directory of images and search them using text queries.
|
||||
|
||||
Usage:
|
||||
python -m apps.image_rag --image-dir ./my_images/ --query "a sunset over mountains"
|
||||
python -m apps.image_rag --image-dir ./my_images/ --interactive
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pickle
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
|
||||
|
||||
class ImageRAG(BaseRAGExample):
|
||||
"""
|
||||
RAG application for images using CLIP embeddings.
|
||||
|
||||
This class provides a complete RAG pipeline for image data, including
|
||||
CLIP embedding generation, indexing, and text-based image search.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Image RAG",
|
||||
description="RAG application for images using CLIP embeddings",
|
||||
default_index_name="image_index",
|
||||
)
|
||||
# Override default embedding model to use CLIP
|
||||
self.embedding_model_default = "clip-ViT-L-14"
|
||||
self.embedding_mode_default = "sentence-transformers"
|
||||
self._image_data: list[dict] = []
|
||||
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add image-specific arguments."""
|
||||
image_group = parser.add_argument_group("Image Parameters")
|
||||
image_group.add_argument(
|
||||
"--image-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory containing images to index",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--image-extensions",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=[".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"],
|
||||
help="Image file extensions to process (default: .jpg .jpeg .png .gif .bmp .webp)",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size for CLIP embedding generation (default: 32)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load images, generate CLIP embeddings, and return text descriptions."""
|
||||
self._image_data = self._load_images_and_embeddings(args)
|
||||
return [entry["text"] for entry in self._image_data]
|
||||
|
||||
def _load_images_and_embeddings(self, args) -> list[dict]:
|
||||
"""Helper to process images and produce embeddings/metadata."""
|
||||
image_dir = Path(args.image_dir)
|
||||
if not image_dir.exists():
|
||||
raise ValueError(f"Image directory does not exist: {image_dir}")
|
||||
|
||||
print(f"📸 Loading images from {image_dir}...")
|
||||
|
||||
# Find all image files
|
||||
image_files = []
|
||||
for ext in args.image_extensions:
|
||||
image_files.extend(image_dir.rglob(f"*{ext}"))
|
||||
image_files.extend(image_dir.rglob(f"*{ext.upper()}"))
|
||||
|
||||
if not image_files:
|
||||
raise ValueError(
|
||||
f"No images found in {image_dir} with extensions {args.image_extensions}"
|
||||
)
|
||||
|
||||
print(f"✅ Found {len(image_files)} images")
|
||||
|
||||
# Limit if max_items is set
|
||||
if args.max_items > 0:
|
||||
image_files = image_files[: args.max_items]
|
||||
print(f"📊 Processing {len(image_files)} images (limited by --max-items)")
|
||||
|
||||
# Load CLIP model
|
||||
print("🔍 Loading CLIP model...")
|
||||
model = SentenceTransformer(self.embedding_model_default)
|
||||
|
||||
# Process images and generate embeddings
|
||||
print("🖼️ Processing images and generating embeddings...")
|
||||
image_data = []
|
||||
batch_images = []
|
||||
batch_paths = []
|
||||
|
||||
for image_path in tqdm(image_files, desc="Processing images"):
|
||||
try:
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
batch_images.append(image)
|
||||
batch_paths.append(image_path)
|
||||
|
||||
# Process in batches
|
||||
if len(batch_images) >= args.batch_size:
|
||||
embeddings = model.encode(
|
||||
batch_images,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=args.batch_size,
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
for img_path, embedding in zip(batch_paths, embeddings):
|
||||
image_data.append(
|
||||
{
|
||||
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||
"metadata": {
|
||||
"image_path": str(img_path),
|
||||
"image_name": img_path.name,
|
||||
"image_dir": str(image_dir),
|
||||
},
|
||||
"embedding": embedding.astype(np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
batch_images = []
|
||||
batch_paths = []
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to process {image_path}: {e}")
|
||||
continue
|
||||
|
||||
# Process remaining images
|
||||
if batch_images:
|
||||
embeddings = model.encode(
|
||||
batch_images,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=len(batch_images),
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
for img_path, embedding in zip(batch_paths, embeddings):
|
||||
image_data.append(
|
||||
{
|
||||
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||
"metadata": {
|
||||
"image_path": str(img_path),
|
||||
"image_name": img_path.name,
|
||||
"image_dir": str(image_dir),
|
||||
},
|
||||
"embedding": embedding.astype(np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
print(f"✅ Processed {len(image_data)} images")
|
||||
return image_data
|
||||
|
||||
async def build_index(self, args, texts: list[str]) -> str:
|
||||
"""Build index using pre-computed CLIP embeddings."""
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if not self._image_data or len(self._image_data) != len(texts):
|
||||
raise RuntimeError("No image data found. Make sure load_data() ran successfully.")
|
||||
|
||||
print("🔨 Building LEANN index with CLIP embeddings...")
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=self.embedding_model_default,
|
||||
embedding_mode=self.embedding_mode_default,
|
||||
is_recompute=False,
|
||||
distance_metric="cosine",
|
||||
graph_degree=args.graph_degree,
|
||||
build_complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
)
|
||||
|
||||
for text, data in zip(texts, self._image_data):
|
||||
builder.add_text(text=text, metadata=data["metadata"])
|
||||
|
||||
ids = [str(i) for i in range(len(self._image_data))]
|
||||
embeddings = np.array([data["embedding"] for data in self._image_data], dtype=np.float32)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="wb", suffix=".pkl", delete=False) as f:
|
||||
pickle.dump((ids, embeddings), f)
|
||||
pkl_path = f.name
|
||||
|
||||
try:
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
builder.build_index_from_embeddings(index_path, pkl_path)
|
||||
print(f"✅ Index built successfully at {index_path}")
|
||||
return index_path
|
||||
finally:
|
||||
Path(pkl_path).unlink()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the image RAG application."""
|
||||
import asyncio
|
||||
|
||||
app = ImageRAG()
|
||||
asyncio.run(app.run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user