change wechat app split logic & merge

This commit is contained in:
yichuan520030910320
2025-07-19 19:44:33 -07:00
2 changed files with 400 additions and 52 deletions

View File

@@ -13,10 +13,10 @@ import sys
import numpy as np
from typing import List
from leann.api import LeannSearcher, LeannBuilder
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
if not data_root.exists():
print(f"Data directory '{data_root}' not found.")
@@ -26,13 +26,32 @@ def download_data_if_needed(data_root: Path):
try:
from huggingface_hub import snapshot_download
if download_embeddings:
# Download everything including embeddings (large files)
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
)
print("Data download complete (including embeddings)!")
else:
# Download only specific folders, excluding embeddings
allow_patterns = [
"ground_truth/**",
"indices/**",
"queries/**",
"*.md",
"*.txt",
]
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=allow_patterns,
)
print("Data download complete (excluding embeddings)!")
except ImportError:
print(
"Error: huggingface_hub is not installed. Please install it to download the data:"
@@ -44,6 +63,43 @@ def download_data_if_needed(data_root: Path):
sys.exit(1)
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
"""Download embeddings files specifically."""
embeddings_dir = data_root / "embeddings"
if dataset_type:
# Check if specific dataset embeddings exist
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
print(f"Embeddings for {dataset_type} already exist")
return str(target_file)
print("Downloading embeddings from HuggingFace Hub...")
try:
from huggingface_hub import snapshot_download
# Download only embeddings folder
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=["embeddings/**/*.pkl"],
)
print("Embeddings download complete!")
if dataset_type:
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
return str(target_file)
return str(embeddings_dir)
except Exception as e:
print(f"Error downloading embeddings: {e}")
sys.exit(1)
# --- Helper Function to get Golden Passages ---
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
"""
@@ -72,12 +128,76 @@ def load_queries(file_path: Path) -> List[str]:
return queries
def build_index_from_embeddings(
embeddings_file: str, output_path: str, backend: str = "hnsw"
):
"""
Build a LEANN index from pre-computed embeddings.
Args:
embeddings_file: Path to pickle file with (ids, embeddings) tuple
output_path: Path where to save the index
backend: Backend to use ("hnsw" or "diskann")
"""
print(f"Building {backend} index from embeddings: {embeddings_file}")
# Create builder with appropriate parameters
if backend == "hnsw":
builder_kwargs = {
"M": 32, # Graph degree
"efConstruction": 256, # Construction complexity
"is_compact": True, # Use compact storage
"is_recompute": True, # Enable pruning for better recall
}
elif backend == "diskann":
builder_kwargs = {
"complexity": 64,
"graph_degree": 32,
"search_memory_maximum": 8.0, # GB
"build_memory_maximum": 16.0, # GB
}
else:
builder_kwargs = {}
builder = LeannBuilder(
backend_name=backend,
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
dimensions=768, # Will be auto-detected from embeddings
**builder_kwargs,
)
# Build index from precomputed embeddings
builder.build_index_from_embeddings(output_path, embeddings_file)
print(f"Index saved to: {output_path}")
return output_path
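# Illustrative sketch of the expected embeddings file layout (an assumption based on
# the docstring above: a pickled (ids, embeddings) tuple; the path is a placeholder
# that mirrors the embeddings/<dataset>/passages_00.pkl layout used elsewhere here):
#
#   import pickle
#   import numpy as np
#   ids = ["passage_0", "passage_1"]                      # one ID per passage
#   vectors = np.random.rand(2, 768).astype(np.float32)   # 768 dims matches contriever-msmarco
#   with open("embeddings/dpr/passages_00.pkl", "wb") as f:
#       pickle.dump((ids, vectors), f)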
def main():
parser = argparse.ArgumentParser(
description="Run recall evaluation on a LEANN index."
)
parser.add_argument(
"index_path",
type=str,
nargs="?",
help="Path to the LEANN index to evaluate or build (optional).",
)
parser.add_argument(
"--mode",
choices=["evaluate", "build"],
default="evaluate",
help="Mode: 'evaluate' existing index or 'build' from embeddings",
)
parser.add_argument(
"--embeddings-file",
type=str,
help="Path to embeddings pickle file (optional for build mode)",
)
parser.add_argument(
"--backend",
choices=["hnsw", "diskann"],
default="hnsw",
help="Backend to use for building index (default: hnsw)",
)
parser.add_argument(
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
@@ -96,8 +216,90 @@ def main():
project_root = Path(__file__).resolve().parent.parent
data_root = project_root / "data"
# Download data based on mode
if args.mode == "build":
# For building mode, we need embeddings
download_data_if_needed(
data_root, download_embeddings=False
) # Basic data first
# Auto-detect dataset type and download embeddings
if args.embeddings_file:
embeddings_file = args.embeddings_file
# Try to detect dataset type from embeddings file path
if "rpj_wiki" in str(embeddings_file):
dataset_type = "rpj_wiki"
elif "dpr" in str(embeddings_file):
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default
else:
# Auto-detect from index path if provided, otherwise default to DPR
if args.index_path:
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default to DPR
else:
dataset_type = "dpr" # Default to DPR
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
# Auto-generate index path if not provided
if not args.index_path:
indices_dir = data_root / "indices" / dataset_type
indices_dir.mkdir(parents=True, exist_ok=True)
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
print(f"Auto-generated index path: {args.index_path}")
print(f"Building index from embeddings: {embeddings_file}")
built_index_path = build_index_from_embeddings(
embeddings_file, args.index_path, args.backend
)
print(f"Index built successfully: {built_index_path}")
# Ask if user wants to run evaluation
eval_response = (
input("Run evaluation on the built index? (y/n): ").strip().lower()
)
if eval_response != "y":
print("Index building complete. Exiting.")
return
else:
# For evaluation mode, don't need embeddings
download_data_if_needed(data_root, download_embeddings=False)
# Auto-detect index path if not provided
if not args.index_path:
# Default to using downloaded indices
indices_dir = data_root / "indices"
# Try common datasets in order of preference
for dataset in ["dpr", "rpj_wiki"]:
dataset_dir = indices_dir / dataset
if dataset_dir.exists():
# Look for index files
index_files = list(dataset_dir.glob("*.index")) + list(
dataset_dir.glob("*_disk.index")
)
if index_files:
args.index_path = str(
index_files[0].with_suffix("")
) # Remove .index extension
print(f"Using index: {args.index_path}")
break
if not args.index_path:
print(
"No indices found. The data download should have included pre-built indices."
)
print(
"Please check the data/indices/ directory or provide --index-path manually."
)
sys.exit(1)
# Detect dataset type from index path to select the correct ground truth
index_path_str = str(args.index_path)
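# Example invocations (illustrative; this script's file name is not shown in the diff,
# so "run_evaluation.py" below is a placeholder, and the index path mirrors the
# auto-generated "<dataset>_from_embeddings" naming above):
#   python run_evaluation.py --mode build --backend hnsw
#   python run_evaluation.py data/indices/dpr/dpr_from_embeddings --mode evaluate --num-queries 10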

View File

@@ -18,15 +18,15 @@ from .chat import get_llm
def compute_embeddings(
chunks: List[str],
model_name: str,
mode: str = "sentence-transformers",
use_server: bool = True,
use_mlx: bool = False,  # Backward compatibility: if True, override mode to 'mlx'
) -> np.ndarray:
"""
Computes embeddings using different backends.
Args:
chunks: List of text chunks to embed
model_name: Name of the embedding model
@@ -35,7 +35,7 @@ def compute_embeddings(
- "mlx": Use MLX backend for Apple Silicon - "mlx": Use MLX backend for Apple Silicon
- "openai": Use OpenAI embedding API - "openai": Use OpenAI embedding API
use_server: Whether to use embedding server (True for search, False for build) use_server: Whether to use embedding server (True for search, False for build)
Returns: Returns:
numpy array of embeddings numpy array of embeddings
""" """
@@ -46,33 +46,41 @@ def compute_embeddings(
# Auto-detect mode based on model name if not explicitly set
if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
mode = "openai"
if mode == "mlx":
return compute_embeddings_mlx(chunks, model_name, batch_size=16)
elif mode == "openai":
return compute_embeddings_openai(chunks, model_name)
elif mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(
chunks, model_name, use_server=use_server
)
else:
raise ValueError(
f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
)
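# Illustrative usage (a sketch; the chunk text is a placeholder, while
# "facebook/contriever-msmarco" is the model name used elsewhere in this change):
#   embs = compute_embeddings(
#       ["hello world"],
#       "facebook/contriever-msmarco",
#       mode="sentence-transformers",
#       use_server=False,
#   )
#   # embs is a float32 numpy array of shape (1, embedding_dim)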
def compute_embeddings_sentence_transformers(
chunks: List[str], model_name: str, use_server: bool = True
) -> np.ndarray:
"""Computes embeddings using sentence-transformers. """Computes embeddings using sentence-transformers.
Args: Args:
chunks: List of text chunks to embed chunks: List of text chunks to embed
model_name: Name of the sentence transformer model model_name: Name of the sentence transformer model
use_server: If True, use embedding server (good for search). If False, use direct computation (good for build). use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
""" """
if not use_server: if not use_server:
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)...") print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
)
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
)
# Use embedding server for sentence-transformers too
# This avoids loading the model twice (once in API, once in server)
try:
@@ -81,49 +89,55 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str,
import msgpack
import numpy as np
from .embedding_server_manager import EmbeddingServerManager
# Ensure embedding server is running
port = 5557
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started = server_manager.start_server(
port=port,
model_name=model_name,
embedding_mode="sentence-transformers",
enable_warmup=False,
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}")
# Connect to embedding server
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{port}")
# Send chunks to server for embedding computation
request = chunks
socket.send(msgpack.packb(request))
# Receive embeddings from server
response = socket.recv()
embeddings_list = msgpack.unpackb(response)
# Convert back to numpy array
embeddings = np.array(embeddings_list, dtype=np.float32)
socket.close()
context.term()
return embeddings
except Exception as e:
# Fallback to direct sentence-transformers if server connection fails
print(
f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
)
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
def _compute_embeddings_sentence_transformers_direct(
chunks: List[str], model_name: str
) -> np.ndarray:
"""Direct sentence-transformers computation (fallback).""" """Direct sentence-transformers computation (fallback)."""
try: try:
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
@@ -164,16 +178,18 @@ def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
raise RuntimeError(
"openai not available. Install with: uv pip install openai"
) from e
# Get API key from environment
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
client = openai.OpenAI(api_key=api_key)
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
)
# OpenAI has a limit on batch size and input length
max_batch_size = 100  # Conservative batch size
all_embeddings = []
@@ -191,18 +207,17 @@ def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
batch_chunks = chunks[i:i + max_batch_size]
try:
response = client.embeddings.create(model=model_name, input=batch_chunks)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
except Exception as e:
print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
print(
f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
)
return embeddings
@@ -345,7 +360,12 @@ class LeannBuilder:
raise ValueError("No chunks added.") raise ValueError("No chunks added.")
if self.dimensions is None: if self.dimensions is None:
self.dimensions = len( self.dimensions = len(
compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode, use_server=False)[0] compute_embeddings(
["dummy"],
self.embedding_model,
self.embedding_mode,
use_server=False,
)[0]
)
path = Path(index_path)
index_dir = path.parent
@@ -414,6 +434,129 @@ class LeannBuilder:
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
def build_index_from_embeddings(self, index_path: str, embeddings_file: str):
"""
Build an index from pre-computed embeddings stored in a pickle file.
Args:
index_path: Path where the index will be saved
embeddings_file: Path to pickle file containing (ids, embeddings) tuple
"""
# Load pre-computed embeddings
with open(embeddings_file, "rb") as f:
data = pickle.load(f)
if not isinstance(data, tuple) or len(data) != 2:
raise ValueError(
f"Invalid embeddings file format. Expected tuple with 2 elements, got {type(data)}"
)
ids, embeddings = data
if not isinstance(embeddings, np.ndarray):
raise ValueError(
f"Expected embeddings to be numpy array, got {type(embeddings)}"
)
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"Mismatch between number of IDs ({len(ids)}) and embeddings ({embeddings.shape[0]})"
)
# Validate/set dimensions
embedding_dim = embeddings.shape[1]
if self.dimensions is None:
self.dimensions = embedding_dim
elif self.dimensions != embedding_dim:
raise ValueError(
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
)
print(
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
)
# Ensure we have text data for each embedding
if len(self.chunks) != len(ids):
# If no text chunks provided, create placeholder text entries
if not self.chunks:
print("No text chunks provided, creating placeholder entries...")
for id_val in ids:
self.add_text(
f"Document {id_val}",
metadata={"id": str(id_val), "from_embeddings": True},
)
else:
raise ValueError(
f"Number of text chunks ({len(self.chunks)}) doesn't match number of embeddings ({len(ids)})"
)
# Build file structure
path = Path(index_path)
index_dir = path.parent
index_name = path.name
index_dir.mkdir(parents=True, exist_ok=True)
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
# Write passages and create offset map
offset_map = {}
with open(passages_file, "w", encoding="utf-8") as f:
for chunk in self.chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk["metadata"],
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
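# Reader-side sketch (illustrative; the searcher's actual passage-loading code is
# outside this change): the offset map written above lets a reader seek straight
# to a passage by ID instead of scanning the JSONL file.
#   with open(offset_file, "rb") as f:
#       offsets = pickle.load(f)
#   with open(passages_file, "r", encoding="utf-8") as f:
#       f.seek(offsets[passage_id])
#       record = json.loads(f.readline())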
# Build the vector index using precomputed embeddings
string_ids = [str(id_val) for id_val in ids]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path)
# Create metadata file
leann_meta_path = index_dir / f"{index_name}.meta.json"
meta_data = {
"version": "1.0",
"backend_name": self.backend_name,
"embedding_model": self.embedding_model,
"dimensions": self.dimensions,
"backend_kwargs": self.backend_kwargs,
"embedding_mode": self.embedding_mode,
"passage_sources": [
{
"type": "jsonl",
"path": str(passages_file),
"index_path": str(offset_file),
}
],
"built_from_precomputed_embeddings": True,
"embeddings_source": str(embeddings_file),
}
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = is_compact and is_recompute
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
print(f"Index built successfully from precomputed embeddings: {index_path}")
class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
@@ -425,7 +568,9 @@ class LeannSearcher:
backend_name = self.meta_data["backend_name"]
self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format
self.embedding_mode = self.meta_data.get(
"embedding_mode", "sentence-transformers"
)
# Backward compatibility with use_mlx
if self.meta_data.get("use_mlx", False):
self.embedding_mode = "mlx"
@@ -457,6 +602,7 @@ class LeannSearcher:
# Use backend's compute_query_embedding method
# This will automatically use embedding server if available and needed
import time
start_time = time.time()
query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
@@ -556,7 +702,7 @@ class LeannChat:
"Please provide the best answer you can based on this context and your knowledge." "Please provide the best answer you can based on this context and your knowledge."
) )
ans=self.llm.ask(prompt, **llm_kwargs) ans = self.llm.ask(prompt, **llm_kwargs)
return ans return ans
def start_interactive(self): def start_interactive(self):