change wechat app split logic & merge

yichuan520030910320
2025-07-19 19:44:33 -07:00
2 changed files with 400 additions and 52 deletions


@@ -13,10 +13,10 @@ import sys
 import numpy as np
 from typing import List
-from leann.api import LeannSearcher
+from leann.api import LeannSearcher, LeannBuilder
 
 
-def download_data_if_needed(data_root: Path):
+def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
     """Checks if the data directory exists, and if not, downloads it from HF Hub."""
     if not data_root.exists():
         print(f"Data directory '{data_root}' not found.")
@@ -26,13 +26,32 @@ def download_data_if_needed(data_root: Path):
     try:
         from huggingface_hub import snapshot_download
 
-        snapshot_download(
-            repo_id="LEANN-RAG/leann-rag-evaluation-data",
-            repo_type="dataset",
-            local_dir=data_root,
-            local_dir_use_symlinks=False,  # Recommended for Windows compatibility and simpler structure
-        )
-        print("Data download complete!")
+        if download_embeddings:
+            # Download everything including embeddings (large files)
+            snapshot_download(
+                repo_id="LEANN-RAG/leann-rag-evaluation-data",
+                repo_type="dataset",
+                local_dir=data_root,
+                local_dir_use_symlinks=False,
+            )
+            print("Data download complete (including embeddings)!")
+        else:
+            # Download only specific folders, excluding embeddings
+            allow_patterns = [
+                "ground_truth/**",
+                "indices/**",
+                "queries/**",
+                "*.md",
+                "*.txt",
+            ]
+            snapshot_download(
+                repo_id="LEANN-RAG/leann-rag-evaluation-data",
+                repo_type="dataset",
+                local_dir=data_root,
+                local_dir_use_symlinks=False,
+                allow_patterns=allow_patterns,
+            )
+            print("Data download complete (excluding embeddings)!")
     except ImportError:
         print(
             "Error: huggingface_hub is not installed. Please install it to download the data:"
@@ -44,6 +63,43 @@ def download_data_if_needed(data_root: Path):
     sys.exit(1)
 
 
+def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
+    """Download embeddings files specifically."""
+    embeddings_dir = data_root / "embeddings"
+
+    if dataset_type:
+        # Check if specific dataset embeddings exist
+        target_file = embeddings_dir / dataset_type / "passages_00.pkl"
+        if target_file.exists():
+            print(f"Embeddings for {dataset_type} already exist")
+            return str(target_file)
+
+    print("Downloading embeddings from HuggingFace Hub...")
+    try:
+        from huggingface_hub import snapshot_download
+
+        # Download only embeddings folder
+        snapshot_download(
+            repo_id="LEANN-RAG/leann-rag-evaluation-data",
+            repo_type="dataset",
+            local_dir=data_root,
+            local_dir_use_symlinks=False,
+            allow_patterns=["embeddings/**/*.pkl"],
+        )
+        print("Embeddings download complete!")
+
+        if dataset_type:
+            target_file = embeddings_dir / dataset_type / "passages_00.pkl"
+            if target_file.exists():
+                return str(target_file)
+
+        return str(embeddings_dir)
+    except Exception as e:
+        print(f"Error downloading embeddings: {e}")
+        sys.exit(1)
+
+
 # --- Helper Function to get Golden Passages ---
 def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
     """
@@ -72,12 +128,76 @@ def load_queries(file_path: Path) -> List[str]:
     return queries
 
 
+def build_index_from_embeddings(
+    embeddings_file: str, output_path: str, backend: str = "hnsw"
+):
+    """
+    Build a LEANN index from pre-computed embeddings.
+
+    Args:
+        embeddings_file: Path to pickle file with (ids, embeddings) tuple
+        output_path: Path where to save the index
+        backend: Backend to use ("hnsw" or "diskann")
+    """
+    print(f"Building {backend} index from embeddings: {embeddings_file}")
+
+    # Create builder with appropriate parameters
+    if backend == "hnsw":
+        builder_kwargs = {
+            "M": 32,  # Graph degree
+            "efConstruction": 256,  # Construction complexity
+            "is_compact": True,  # Use compact storage
+            "is_recompute": True,  # Enable pruning for better recall
+        }
+    elif backend == "diskann":
+        builder_kwargs = {
+            "complexity": 64,
+            "graph_degree": 32,
+            "search_memory_maximum": 8.0,  # GB
+            "build_memory_maximum": 16.0,  # GB
+        }
+    else:
+        builder_kwargs = {}
+
+    builder = LeannBuilder(
+        backend_name=backend,
+        embedding_model="facebook/contriever-msmarco",  # Model used to create embeddings
+        dimensions=768,  # Will be auto-detected from embeddings
+        **builder_kwargs,
+    )
+
+    # Build index from precomputed embeddings
+    builder.build_index_from_embeddings(output_path, embeddings_file)
+
+    print(f"Index saved to: {output_path}")
+    return output_path
+
+
 def main():
     parser = argparse.ArgumentParser(
         description="Run recall evaluation on a LEANN index."
     )
     parser.add_argument(
-        "index_path", type=str, help="Path to the LEANN index to evaluate."
+        "index_path",
+        type=str,
+        nargs="?",
+        help="Path to the LEANN index to evaluate or build (optional).",
+    )
+    parser.add_argument(
+        "--mode",
+        choices=["evaluate", "build"],
+        default="evaluate",
+        help="Mode: 'evaluate' existing index or 'build' from embeddings",
+    )
+    parser.add_argument(
+        "--embeddings-file",
+        type=str,
+        help="Path to embeddings pickle file (optional for build mode)",
+    )
+    parser.add_argument(
+        "--backend",
+        choices=["hnsw", "diskann"],
+        default="hnsw",
+        help="Backend to use for building index (default: hnsw)",
+    )
     parser.add_argument(
         "--num-queries", type=int, default=10, help="Number of queries to evaluate."
@@ -96,8 +216,90 @@ def main():
     project_root = Path(__file__).resolve().parent.parent
     data_root = project_root / "data"
 
-    # Automatically download data if it doesn't exist
-    download_data_if_needed(data_root)
+    # Download data based on mode
+    if args.mode == "build":
+        # For building mode, we need embeddings
+        download_data_if_needed(
+            data_root, download_embeddings=False
+        )  # Basic data first
+
+        # Auto-detect dataset type and download embeddings
+        if args.embeddings_file:
+            embeddings_file = args.embeddings_file
+            # Try to detect dataset type from embeddings file path
+            if "rpj_wiki" in str(embeddings_file):
+                dataset_type = "rpj_wiki"
+            elif "dpr" in str(embeddings_file):
+                dataset_type = "dpr"
+            else:
+                dataset_type = "dpr"  # Default
+        else:
+            # Auto-detect from index path if provided, otherwise default to DPR
+            if args.index_path:
+                index_path_str = str(args.index_path)
+                if "rpj_wiki" in index_path_str:
+                    dataset_type = "rpj_wiki"
+                elif "dpr" in index_path_str:
+                    dataset_type = "dpr"
+                else:
+                    dataset_type = "dpr"  # Default to DPR
+            else:
+                dataset_type = "dpr"  # Default to DPR
+            embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
+
+        # Auto-generate index path if not provided
+        if not args.index_path:
+            indices_dir = data_root / "indices" / dataset_type
+            indices_dir.mkdir(parents=True, exist_ok=True)
+            args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
+            print(f"Auto-generated index path: {args.index_path}")
+
+        print(f"Building index from embeddings: {embeddings_file}")
+        built_index_path = build_index_from_embeddings(
+            embeddings_file, args.index_path, args.backend
+        )
+        print(f"Index built successfully: {built_index_path}")
+
+        # Ask if user wants to run evaluation
+        eval_response = (
+            input("Run evaluation on the built index? (y/n): ").strip().lower()
+        )
+        if eval_response != "y":
+            print("Index building complete. Exiting.")
+            return
+    else:
+        # For evaluation mode, don't need embeddings
+        download_data_if_needed(data_root, download_embeddings=False)
+
+        # Auto-detect index path if not provided
+        if not args.index_path:
+            # Default to using downloaded indices
+            indices_dir = data_root / "indices"
+
+            # Try common datasets in order of preference
+            for dataset in ["dpr", "rpj_wiki"]:
+                dataset_dir = indices_dir / dataset
+                if dataset_dir.exists():
+                    # Look for index files
+                    index_files = list(dataset_dir.glob("*.index")) + list(
+                        dataset_dir.glob("*_disk.index")
+                    )
+                    if index_files:
+                        args.index_path = str(
+                            index_files[0].with_suffix("")
+                        )  # Remove .index extension
+                        print(f"Using index: {args.index_path}")
+                        break
+
+            if not args.index_path:
+                print(
+                    "No indices found. The data download should have included pre-built indices."
+                )
+                print(
+                    "Please check the data/indices/ directory or provide --index-path manually."
+                )
+                sys.exit(1)
+
     # Detect dataset type from index path to select the correct ground truth
     index_path_str = str(args.index_path)
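
The dataset-type sniffing in build mode is plain substring matching with a DPR fallback; a compact equivalent of the repeated branches (purely illustrative, not part of the commit) would be:

```python
def detect_dataset_type(path_hint: str) -> str:
    """Guess the dataset from a path hint, defaulting to "dpr" (mirrors the logic above)."""
    for name in ("rpj_wiki", "dpr"):
        if name in path_hint:
            return name
    return "dpr"
```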


@@ -22,7 +22,7 @@ def compute_embeddings(
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
-    use_mlx: bool = False # Backward compatibility: if True, override mode to 'mlx'
+    use_mlx: bool = False,  # Backward compatibility: if True, override mode to 'mlx'
 ) -> np.ndarray:
     """
     Computes embeddings using different backends.
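
For orientation, the dispatcher can be called directly; a minimal sketch, assuming this module's helpers are in scope (the model name is the contriever default used elsewhere in this commit):

```python
chunks = ["LEANN builds compact vector indexes.", "Recall is measured on golden passages."]

# Server-backed path (the search-time default); falls back to direct computation on failure.
vecs = compute_embeddings(chunks, "facebook/contriever-msmarco")

# Direct path (used at build time), bypassing the embedding server entirely.
vecs_direct = compute_embeddings(
    chunks, "facebook/contriever-msmarco", mode="sentence-transformers", use_server=False
)
print(vecs.shape, vecs_direct.shape)  # both (2, 768) for this model
```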
@@ -52,12 +52,18 @@ def compute_embeddings(
elif mode == "openai": elif mode == "openai":
return compute_embeddings_openai(chunks, model_name) return compute_embeddings_openai(chunks, model_name)
elif mode == "sentence-transformers": elif mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(chunks, model_name, use_server=use_server) return compute_embeddings_sentence_transformers(
chunks, model_name, use_server=use_server
)
else: else:
raise ValueError(f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai") raise ValueError(
f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
)
def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str, use_server: bool = True) -> np.ndarray: def compute_embeddings_sentence_transformers(
chunks: List[str], model_name: str, use_server: bool = True
) -> np.ndarray:
"""Computes embeddings using sentence-transformers. """Computes embeddings using sentence-transformers.
Args: Args:
@@ -66,7 +72,9 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str,
         use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
     """
     if not use_server:
-        print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)...")
+        print(
+            f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
+        )
         return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
 
     print(
@@ -84,7 +92,9 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str,
     # Ensure embedding server is running
     port = 5557
-    server_manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw.hnsw_embedding_server")
+    server_manager = EmbeddingServerManager(
+        backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
+    )
 
     server_started = server_manager.start_server(
         port=port,
@@ -119,11 +129,15 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str,
     except Exception as e:
         # Fallback to direct sentence-transformers if server connection fails
-        print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}")
+        print(
+            f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
+        )
         return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
 
 
-def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray:
+def _compute_embeddings_sentence_transformers_direct(
+    chunks: List[str], model_name: str
+) -> np.ndarray:
     """Direct sentence-transformers computation (fallback)."""
     try:
         from sentence_transformers import SentenceTransformer
@@ -172,7 +186,9 @@ def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
     client = openai.OpenAI(api_key=api_key)
 
-    print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...")
+    print(
+        f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
+    )
 
     # OpenAI has a limit on batch size and input length
     max_batch_size = 100  # Conservative batch size
@@ -191,10 +207,7 @@ def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
         batch_chunks = chunks[i:i + max_batch_size]
 
         try:
-            response = client.embeddings.create(
-                model=model_name,
-                input=batch_chunks
-            )
+            response = client.embeddings.create(model=model_name, input=batch_chunks)
             batch_embeddings = [embedding.embedding for embedding in response.data]
             all_embeddings.extend(batch_embeddings)
         except Exception as e:
@@ -202,7 +215,9 @@ def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
             raise
 
     embeddings = np.array(all_embeddings, dtype=np.float32)
-    print(f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}")
+    print(
+        f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
+    )
     return embeddings
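
The OpenAI path batches requests (100 chunks per call) and constructs the client from an API key; a hedged usage sketch (the key handling and model name are assumptions, not fixed by this diff):

```python
import os

# Placeholder key; the client is assumed to be configured from the environment.
os.environ.setdefault("OPENAI_API_KEY", "sk-...")

vecs = compute_embeddings(
    ["What is LEANN?", "How is recall@k computed?"],
    "text-embedding-3-small",  # illustrative OpenAI embedding model
    mode="openai",
)
print(vecs.dtype, vecs.shape)  # float32, (2, <model dimension>)
```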
@@ -345,7 +360,12 @@ class LeannBuilder:
raise ValueError("No chunks added.") raise ValueError("No chunks added.")
if self.dimensions is None: if self.dimensions is None:
self.dimensions = len( self.dimensions = len(
compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode, use_server=False)[0] compute_embeddings(
["dummy"],
self.embedding_model,
self.embedding_mode,
use_server=False,
)[0]
) )
path = Path(index_path) path = Path(index_path)
index_dir = path.parent index_dir = path.parent
@@ -414,6 +434,129 @@ class LeannBuilder:
         with open(leann_meta_path, "w", encoding="utf-8") as f:
             json.dump(meta_data, f, indent=2)
 
+    def build_index_from_embeddings(self, index_path: str, embeddings_file: str):
+        """
+        Build an index from pre-computed embeddings stored in a pickle file.
+
+        Args:
+            index_path: Path where the index will be saved
+            embeddings_file: Path to pickle file containing (ids, embeddings) tuple
+        """
+        # Load pre-computed embeddings
+        with open(embeddings_file, "rb") as f:
+            data = pickle.load(f)
+
+        if not isinstance(data, tuple) or len(data) != 2:
+            raise ValueError(
+                f"Invalid embeddings file format. Expected tuple with 2 elements, got {type(data)}"
+            )
+
+        ids, embeddings = data
+        if not isinstance(embeddings, np.ndarray):
+            raise ValueError(
+                f"Expected embeddings to be numpy array, got {type(embeddings)}"
+            )
+        if len(ids) != embeddings.shape[0]:
+            raise ValueError(
+                f"Mismatch between number of IDs ({len(ids)}) and embeddings ({embeddings.shape[0]})"
+            )
+
+        # Validate/set dimensions
+        embedding_dim = embeddings.shape[1]
+        if self.dimensions is None:
+            self.dimensions = embedding_dim
+        elif self.dimensions != embedding_dim:
+            raise ValueError(
+                f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
+            )
+
+        print(
+            f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
+        )
+
+        # Ensure we have text data for each embedding
+        if len(self.chunks) != len(ids):
+            # If no text chunks provided, create placeholder text entries
+            if not self.chunks:
+                print("No text chunks provided, creating placeholder entries...")
+                for id_val in ids:
+                    self.add_text(
+                        f"Document {id_val}",
+                        metadata={"id": str(id_val), "from_embeddings": True},
+                    )
+            else:
+                raise ValueError(
+                    f"Number of text chunks ({len(self.chunks)}) doesn't match number of embeddings ({len(ids)})"
+                )
+
+        # Build file structure
+        path = Path(index_path)
+        index_dir = path.parent
+        index_name = path.name
+        index_dir.mkdir(parents=True, exist_ok=True)
+
+        passages_file = index_dir / f"{index_name}.passages.jsonl"
+        offset_file = index_dir / f"{index_name}.passages.idx"
+
+        # Write passages and create offset map
+        offset_map = {}
+        with open(passages_file, "w", encoding="utf-8") as f:
+            for chunk in self.chunks:
+                offset = f.tell()
+                json.dump(
+                    {
+                        "id": chunk["id"],
+                        "text": chunk["text"],
+                        "metadata": chunk["metadata"],
+                    },
+                    f,
+                    ensure_ascii=False,
+                )
+                f.write("\n")
+                offset_map[chunk["id"]] = offset
+
+        with open(offset_file, "wb") as f:
+            pickle.dump(offset_map, f)
+
+        # Build the vector index using precomputed embeddings
+        string_ids = [str(id_val) for id_val in ids]
+        current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
+        builder_instance = self.backend_factory.builder(**current_backend_kwargs)
+        builder_instance.build(embeddings, string_ids, index_path)
+
+        # Create metadata file
+        leann_meta_path = index_dir / f"{index_name}.meta.json"
+        meta_data = {
+            "version": "1.0",
+            "backend_name": self.backend_name,
+            "embedding_model": self.embedding_model,
+            "dimensions": self.dimensions,
+            "backend_kwargs": self.backend_kwargs,
+            "embedding_mode": self.embedding_mode,
+            "passage_sources": [
+                {
+                    "type": "jsonl",
+                    "path": str(passages_file),
+                    "index_path": str(offset_file),
+                }
+            ],
+            "built_from_precomputed_embeddings": True,
+            "embeddings_source": str(embeddings_file),
+        }
+
+        # Add storage status flags for HNSW backend
+        if self.backend_name == "hnsw":
+            is_compact = self.backend_kwargs.get("is_compact", True)
+            is_recompute = self.backend_kwargs.get("is_recompute", True)
+            meta_data["is_compact"] = is_compact
+            meta_data["is_pruned"] = is_compact and is_recompute
+
+        with open(leann_meta_path, "w", encoding="utf-8") as f:
+            json.dump(meta_data, f, indent=2)
+
+        print(f"Index built successfully from precomputed embeddings: {index_path}")
+
 
 class LeannSearcher:
     def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
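
A minimal end-to-end sketch of the new builder entry point (paths are hypothetical; backend parameters mirror the evaluation script earlier in this commit):

```python
from leann.api import LeannBuilder

builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="facebook/contriever-msmarco",
    M=32,
    efConstruction=256,
    is_compact=True,
    is_recompute=True,
)
# No add_text() calls are required: missing chunks become "Document <id>"
# placeholders, and dimensions are taken from the pickle if not set.
builder.build_index_from_embeddings(
    "data/indices/dpr/dpr_from_embeddings",  # index output path (hypothetical)
    "data/embeddings/dpr/passages_00.pkl",   # (ids, embeddings) pickle
)
```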
@@ -425,7 +568,9 @@ class LeannSearcher:
         backend_name = self.meta_data["backend_name"]
         self.embedding_model = self.meta_data["embedding_model"]
         # Support both old and new format
-        self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
+        self.embedding_mode = self.meta_data.get(
+            "embedding_mode", "sentence-transformers"
+        )
         # Backward compatibility with use_mlx
         if self.meta_data.get("use_mlx", False):
             self.embedding_mode = "mlx"
@@ -457,6 +602,7 @@ class LeannSearcher:
         # Use backend's compute_query_embedding method
         # This will automatically use embedding server if available and needed
         import time
+
         start_time = time.time()
         query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
@@ -556,7 +702,7 @@ class LeannChat:
"Please provide the best answer you can based on this context and your knowledge." "Please provide the best answer you can based on this context and your knowledge."
) )
ans=self.llm.ask(prompt, **llm_kwargs) ans = self.llm.ask(prompt, **llm_kwargs)
return ans return ans
def start_interactive(self): def start_interactive(self):