change wechat app split logic & merge
@@ -13,10 +13,10 @@ import sys
 import numpy as np
 from typing import List
 
-from leann.api import LeannSearcher
+from leann.api import LeannSearcher, LeannBuilder
 
 
-def download_data_if_needed(data_root: Path):
+def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
     """Checks if the data directory exists, and if not, downloads it from HF Hub."""
     if not data_root.exists():
         print(f"Data directory '{data_root}' not found.")
@@ -26,13 +26,32 @@ def download_data_if_needed(data_root: Path):
         try:
             from huggingface_hub import snapshot_download
 
-            snapshot_download(
-                repo_id="LEANN-RAG/leann-rag-evaluation-data",
-                repo_type="dataset",
-                local_dir=data_root,
-                local_dir_use_symlinks=False,  # Recommended for Windows compatibility and simpler structure
-            )
-            print("Data download complete!")
+            if download_embeddings:
+                # Download everything including embeddings (large files)
+                snapshot_download(
+                    repo_id="LEANN-RAG/leann-rag-evaluation-data",
+                    repo_type="dataset",
+                    local_dir=data_root,
+                    local_dir_use_symlinks=False,
+                )
+                print("Data download complete (including embeddings)!")
+            else:
+                # Download only specific folders, excluding embeddings
+                allow_patterns = [
+                    "ground_truth/**",
+                    "indices/**",
+                    "queries/**",
+                    "*.md",
+                    "*.txt",
+                ]
+                snapshot_download(
+                    repo_id="LEANN-RAG/leann-rag-evaluation-data",
+                    repo_type="dataset",
+                    local_dir=data_root,
+                    local_dir_use_symlinks=False,
+                    allow_patterns=allow_patterns,
+                )
+                print("Data download complete (excluding embeddings)!")
         except ImportError:
             print(
                 "Error: huggingface_hub is not installed. Please install it to download the data:"
@@ -44,6 +63,43 @@ def download_data_if_needed(data_root: Path):
             sys.exit(1)
 
 
+def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
+    """Download embeddings files specifically."""
+    embeddings_dir = data_root / "embeddings"
+
+    if dataset_type:
+        # Check if specific dataset embeddings exist
+        target_file = embeddings_dir / dataset_type / "passages_00.pkl"
+        if target_file.exists():
+            print(f"Embeddings for {dataset_type} already exist")
+            return str(target_file)
+
+    print("Downloading embeddings from HuggingFace Hub...")
+    try:
+        from huggingface_hub import snapshot_download
+
+        # Download only embeddings folder
+        snapshot_download(
+            repo_id="LEANN-RAG/leann-rag-evaluation-data",
+            repo_type="dataset",
+            local_dir=data_root,
+            local_dir_use_symlinks=False,
+            allow_patterns=["embeddings/**/*.pkl"],
+        )
+        print("Embeddings download complete!")
+
+        if dataset_type:
+            target_file = embeddings_dir / dataset_type / "passages_00.pkl"
+            if target_file.exists():
+                return str(target_file)
+
+        return str(embeddings_dir)
+
+    except Exception as e:
+        print(f"Error downloading embeddings: {e}")
+        sys.exit(1)
+
+
 # --- Helper Function to get Golden Passages ---
 def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
     """
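
With the hunk above, the download is split in two: `download_data_if_needed` fetches the light artifacts (indices, queries, ground truth) and `download_embeddings_if_needed` pulls the large embedding pickles only on demand. A minimal usage sketch, not part of the commit (it assumes both helpers are importable and the repository's `data/` layout):

    from pathlib import Path

    data_root = Path("data")
    download_data_if_needed(data_root)  # excludes embeddings/ by default
    pkl_path = download_embeddings_if_needed(data_root, dataset_type="dpr")
    print(pkl_path)  # data/embeddings/dpr/passages_00.pkl once downloaded
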
@@ -72,12 +128,76 @@ def load_queries(file_path: Path) -> List[str]:
     return queries
 
 
+def build_index_from_embeddings(
+    embeddings_file: str, output_path: str, backend: str = "hnsw"
+):
+    """
+    Build a LEANN index from pre-computed embeddings.
+
+    Args:
+        embeddings_file: Path to pickle file with (ids, embeddings) tuple
+        output_path: Path where to save the index
+        backend: Backend to use ("hnsw" or "diskann")
+    """
+    print(f"Building {backend} index from embeddings: {embeddings_file}")
+
+    # Create builder with appropriate parameters
+    if backend == "hnsw":
+        builder_kwargs = {
+            "M": 32,  # Graph degree
+            "efConstruction": 256,  # Construction complexity
+            "is_compact": True,  # Use compact storage
+            "is_recompute": True,  # Enable pruning for better recall
+        }
+    elif backend == "diskann":
+        builder_kwargs = {
+            "complexity": 64,
+            "graph_degree": 32,
+            "search_memory_maximum": 8.0,  # GB
+            "build_memory_maximum": 16.0,  # GB
+        }
+    else:
+        builder_kwargs = {}
+
+    builder = LeannBuilder(
+        backend_name=backend,
+        embedding_model="facebook/contriever-msmarco",  # Model used to create embeddings
+        dimensions=768,  # Will be auto-detected from embeddings
+        **builder_kwargs,
+    )
+
+    # Build index from precomputed embeddings
+    builder.build_index_from_embeddings(output_path, embeddings_file)
+    print(f"Index saved to: {output_path}")
+    return output_path
+
+
 def main():
     parser = argparse.ArgumentParser(
         description="Run recall evaluation on a LEANN index."
     )
     parser.add_argument(
-        "index_path", type=str, help="Path to the LEANN index to evaluate."
+        "index_path",
+        type=str,
+        nargs="?",
+        help="Path to the LEANN index to evaluate or build (optional).",
+    )
+    parser.add_argument(
+        "--mode",
+        choices=["evaluate", "build"],
+        default="evaluate",
+        help="Mode: 'evaluate' existing index or 'build' from embeddings",
+    )
+    parser.add_argument(
+        "--embeddings-file",
+        type=str,
+        help="Path to embeddings pickle file (optional for build mode)",
+    )
+    parser.add_argument(
+        "--backend",
+        choices=["hnsw", "diskann"],
+        default="hnsw",
+        help="Backend to use for building index (default: hnsw)",
     )
     parser.add_argument(
         "--num-queries", type=int, default=10, help="Number of queries to evaluate."
@@ -96,8 +216,90 @@ def main():
     project_root = Path(__file__).resolve().parent.parent
     data_root = project_root / "data"
 
-    # Automatically download data if it doesn't exist
-    download_data_if_needed(data_root)
+    # Download data based on mode
+    if args.mode == "build":
+        # For building mode, we need embeddings
+        download_data_if_needed(
+            data_root, download_embeddings=False
+        )  # Basic data first
+
+        # Auto-detect dataset type and download embeddings
+        if args.embeddings_file:
+            embeddings_file = args.embeddings_file
+            # Try to detect dataset type from embeddings file path
+            if "rpj_wiki" in str(embeddings_file):
+                dataset_type = "rpj_wiki"
+            elif "dpr" in str(embeddings_file):
+                dataset_type = "dpr"
+            else:
+                dataset_type = "dpr"  # Default
+        else:
+            # Auto-detect from index path if provided, otherwise default to DPR
+            if args.index_path:
+                index_path_str = str(args.index_path)
+                if "rpj_wiki" in index_path_str:
+                    dataset_type = "rpj_wiki"
+                elif "dpr" in index_path_str:
+                    dataset_type = "dpr"
+                else:
+                    dataset_type = "dpr"  # Default to DPR
+            else:
+                dataset_type = "dpr"  # Default to DPR
+
+            embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
+
+        # Auto-generate index path if not provided
+        if not args.index_path:
+            indices_dir = data_root / "indices" / dataset_type
+            indices_dir.mkdir(parents=True, exist_ok=True)
+            args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
+            print(f"Auto-generated index path: {args.index_path}")
+
+        print(f"Building index from embeddings: {embeddings_file}")
+        built_index_path = build_index_from_embeddings(
+            embeddings_file, args.index_path, args.backend
+        )
+        print(f"Index built successfully: {built_index_path}")
+
+        # Ask if user wants to run evaluation
+        eval_response = (
+            input("Run evaluation on the built index? (y/n): ").strip().lower()
+        )
+        if eval_response != "y":
+            print("Index building complete. Exiting.")
+            return
+    else:
+        # For evaluation mode, don't need embeddings
+        download_data_if_needed(data_root, download_embeddings=False)
+
+    # Auto-detect index path if not provided
+    if not args.index_path:
+        # Default to using downloaded indices
+        indices_dir = data_root / "indices"
+
+        # Try common datasets in order of preference
+        for dataset in ["dpr", "rpj_wiki"]:
+            dataset_dir = indices_dir / dataset
+            if dataset_dir.exists():
+                # Look for index files
+                index_files = list(dataset_dir.glob("*.index")) + list(
+                    dataset_dir.glob("*_disk.index")
+                )
+                if index_files:
+                    args.index_path = str(
+                        index_files[0].with_suffix("")
+                    )  # Remove .index extension
+                    print(f"Using index: {args.index_path}")
+                    break
+
+        if not args.index_path:
+            print(
+                "No indices found. The data download should have included pre-built indices."
+            )
+            print(
+                "Please check the data/indices/ directory or provide --index-path manually."
+            )
+            sys.exit(1)
+
     # Detect dataset type from index path to select the correct ground truth
     index_path_str = str(args.index_path)
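
That ends the first file (the recall-evaluation script). In `--mode build`, the control flow above boils down to roughly this call sequence (a sketch; the dpr names are the defaults the detection logic falls back to):

    data_root = Path("data")
    embeddings_file = download_embeddings_if_needed(data_root, "dpr")
    index_path = str(data_root / "indices" / "dpr" / "dpr_from_embeddings")
    build_index_from_embeddings(embeddings_file, index_path, backend="hnsw")

The second file, below, is the library side (it defines `compute_embeddings`, `LeannBuilder`, `LeannSearcher`, and `LeannChat`; the `from .chat import get_llm` context suggests `leann/api.py`). Most hunks are formatting reflows; the substantive change is the new `LeannBuilder.build_index_from_embeddings` method the script calls.
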
@@ -18,15 +18,15 @@ from .chat import get_llm
 
 
 def compute_embeddings(
     chunks: List[str],
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
-    use_mlx: bool = False  # Backward compatibility: if True, override mode to 'mlx'
+    use_mlx: bool = False,  # Backward compatibility: if True, override mode to 'mlx'
 ) -> np.ndarray:
     """
     Computes embeddings using different backends.
 
     Args:
         chunks: List of text chunks to embed
         model_name: Name of the embedding model
@@ -35,7 +35,7 @@ def compute_embeddings(
             - "mlx": Use MLX backend for Apple Silicon
             - "openai": Use OpenAI embedding API
         use_server: Whether to use embedding server (True for search, False for build)
 
     Returns:
         numpy array of embeddings
     """
@@ -46,33 +46,41 @@ def compute_embeddings(
     # Auto-detect mode based on model name if not explicitly set
     if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
         mode = "openai"
 
     if mode == "mlx":
         return compute_embeddings_mlx(chunks, model_name, batch_size=16)
     elif mode == "openai":
         return compute_embeddings_openai(chunks, model_name)
     elif mode == "sentence-transformers":
-        return compute_embeddings_sentence_transformers(chunks, model_name, use_server=use_server)
+        return compute_embeddings_sentence_transformers(
+            chunks, model_name, use_server=use_server
+        )
     else:
-        raise ValueError(f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai")
+        raise ValueError(
+            f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
+        )
 
 
-def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str, use_server: bool = True) -> np.ndarray:
+def compute_embeddings_sentence_transformers(
+    chunks: List[str], model_name: str, use_server: bool = True
+) -> np.ndarray:
     """Computes embeddings using sentence-transformers.
 
     Args:
         chunks: List of text chunks to embed
         model_name: Name of the sentence transformer model
         use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
     """
     if not use_server:
-        print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)...")
+        print(
+            f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
+        )
         return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
 
     print(
         f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
     )
 
     # Use embedding server for sentence-transformers too
     # This avoids loading the model twice (once in API, once in server)
     try:
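
Given the dispatch above, callers usually don't pass `mode` at all; a sketch of the two common paths (the model names are illustrative, only `facebook/contriever-msmarco` appears in this commit):

    # OpenAI-style names are auto-routed via the "text-embedding-" prefix check:
    vecs = compute_embeddings(["hello world"], "text-embedding-3-small")
    # Hugging Face names stay on sentence-transformers; use_server=False keeps
    # index builds from spinning up the ZMQ embedding server:
    vecs = compute_embeddings(
        ["hello world"], "facebook/contriever-msmarco", use_server=False
    )
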
@@ -81,49 +89,55 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str,
         import msgpack
         import numpy as np
         from .embedding_server_manager import EmbeddingServerManager
 
         # Ensure embedding server is running
         port = 5557
-        server_manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw.hnsw_embedding_server")
+        server_manager = EmbeddingServerManager(
+            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
+        )
 
         server_started = server_manager.start_server(
             port=port,
             model_name=model_name,
             embedding_mode="sentence-transformers",
             enable_warmup=False,
         )
 
         if not server_started:
             raise RuntimeError(f"Failed to start embedding server on port {port}")
 
         # Connect to embedding server
         context = zmq.Context()
         socket = context.socket(zmq.REQ)
         socket.connect(f"tcp://localhost:{port}")
 
         # Send chunks to server for embedding computation
         request = chunks
         socket.send(msgpack.packb(request))
 
         # Receive embeddings from server
         response = socket.recv()
         embeddings_list = msgpack.unpackb(response)
 
         # Convert back to numpy array
         embeddings = np.array(embeddings_list, dtype=np.float32)
 
         socket.close()
         context.term()
 
         return embeddings
 
     except Exception as e:
         # Fallback to direct sentence-transformers if server connection fails
-        print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}")
+        print(
+            f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
+        )
         return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
 
 
-def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray:
+def _compute_embeddings_sentence_transformers_direct(
+    chunks: List[str], model_name: str
+) -> np.ndarray:
     """Direct sentence-transformers computation (fallback)."""
     try:
         from sentence_transformers import SentenceTransformer
@@ -164,16 +178,18 @@ def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
         raise RuntimeError(
             "openai not available. Install with: uv pip install openai"
         ) from e
 
     # Get API key from environment
     api_key = os.getenv("OPENAI_API_KEY")
     if not api_key:
         raise RuntimeError("OPENAI_API_KEY environment variable not set")
 
     client = openai.OpenAI(api_key=api_key)
 
-    print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...")
+    print(
+        f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
+    )
 
     # OpenAI has a limit on batch size and input length
     max_batch_size = 100  # Conservative batch size
     all_embeddings = []
@@ -191,18 +207,17 @@ def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
         batch_chunks = chunks[i:i + max_batch_size]
 
         try:
-            response = client.embeddings.create(
-                model=model_name,
-                input=batch_chunks
-            )
+            response = client.embeddings.create(model=model_name, input=batch_chunks)
             batch_embeddings = [embedding.embedding for embedding in response.data]
             all_embeddings.extend(batch_embeddings)
         except Exception as e:
             print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
             raise
 
     embeddings = np.array(all_embeddings, dtype=np.float32)
-    print(f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}")
+    print(
+        f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
+    )
     return embeddings
 
 
@@ -345,7 +360,12 @@ class LeannBuilder:
             raise ValueError("No chunks added.")
         if self.dimensions is None:
             self.dimensions = len(
-                compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode, use_server=False)[0]
+                compute_embeddings(
+                    ["dummy"],
+                    self.embedding_model,
+                    self.embedding_mode,
+                    use_server=False,
+                )[0]
             )
         path = Path(index_path)
         index_dir = path.parent
@@ -414,6 +434,129 @@ class LeannBuilder:
         with open(leann_meta_path, "w", encoding="utf-8") as f:
             json.dump(meta_data, f, indent=2)
 
+    def build_index_from_embeddings(self, index_path: str, embeddings_file: str):
+        """
+        Build an index from pre-computed embeddings stored in a pickle file.
+
+        Args:
+            index_path: Path where the index will be saved
+            embeddings_file: Path to pickle file containing (ids, embeddings) tuple
+        """
+        # Load pre-computed embeddings
+        with open(embeddings_file, "rb") as f:
+            data = pickle.load(f)
+
+        if not isinstance(data, tuple) or len(data) != 2:
+            raise ValueError(
+                f"Invalid embeddings file format. Expected tuple with 2 elements, got {type(data)}"
+            )
+
+        ids, embeddings = data
+
+        if not isinstance(embeddings, np.ndarray):
+            raise ValueError(
+                f"Expected embeddings to be numpy array, got {type(embeddings)}"
+            )
+
+        if len(ids) != embeddings.shape[0]:
+            raise ValueError(
+                f"Mismatch between number of IDs ({len(ids)}) and embeddings ({embeddings.shape[0]})"
+            )
+
+        # Validate/set dimensions
+        embedding_dim = embeddings.shape[1]
+        if self.dimensions is None:
+            self.dimensions = embedding_dim
+        elif self.dimensions != embedding_dim:
+            raise ValueError(
+                f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
+            )
+
+        print(
+            f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
+        )
+
+        # Ensure we have text data for each embedding
+        if len(self.chunks) != len(ids):
+            # If no text chunks provided, create placeholder text entries
+            if not self.chunks:
+                print("No text chunks provided, creating placeholder entries...")
+                for id_val in ids:
+                    self.add_text(
+                        f"Document {id_val}",
+                        metadata={"id": str(id_val), "from_embeddings": True},
+                    )
+            else:
+                raise ValueError(
+                    f"Number of text chunks ({len(self.chunks)}) doesn't match number of embeddings ({len(ids)})"
+                )
+
+        # Build file structure
+        path = Path(index_path)
+        index_dir = path.parent
+        index_name = path.name
+        index_dir.mkdir(parents=True, exist_ok=True)
+        passages_file = index_dir / f"{index_name}.passages.jsonl"
+        offset_file = index_dir / f"{index_name}.passages.idx"
+
+        # Write passages and create offset map
+        offset_map = {}
+        with open(passages_file, "w", encoding="utf-8") as f:
+            for chunk in self.chunks:
+                offset = f.tell()
+                json.dump(
+                    {
+                        "id": chunk["id"],
+                        "text": chunk["text"],
+                        "metadata": chunk["metadata"],
+                    },
+                    f,
+                    ensure_ascii=False,
+                )
+                f.write("\n")
+                offset_map[chunk["id"]] = offset
+
+        with open(offset_file, "wb") as f:
+            pickle.dump(offset_map, f)
+
+        # Build the vector index using precomputed embeddings
+        string_ids = [str(id_val) for id_val in ids]
+        current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
+        builder_instance = self.backend_factory.builder(**current_backend_kwargs)
+        builder_instance.build(embeddings, string_ids, index_path)
+
+        # Create metadata file
+        leann_meta_path = index_dir / f"{index_name}.meta.json"
+        meta_data = {
+            "version": "1.0",
+            "backend_name": self.backend_name,
+            "embedding_model": self.embedding_model,
+            "dimensions": self.dimensions,
+            "backend_kwargs": self.backend_kwargs,
+            "embedding_mode": self.embedding_mode,
+            "passage_sources": [
+                {
+                    "type": "jsonl",
+                    "path": str(passages_file),
+                    "index_path": str(offset_file),
+                }
+            ],
+            "built_from_precomputed_embeddings": True,
+            "embeddings_source": str(embeddings_file),
+        }
+
+        # Add storage status flags for HNSW backend
+        if self.backend_name == "hnsw":
+            is_compact = self.backend_kwargs.get("is_compact", True)
+            is_recompute = self.backend_kwargs.get("is_recompute", True)
+            meta_data["is_compact"] = is_compact
+            meta_data["is_pruned"] = is_compact and is_recompute
+
+        with open(leann_meta_path, "w", encoding="utf-8") as f:
+            json.dump(meta_data, f, indent=2)
+
+        print(f"Index built successfully from precomputed embeddings: {index_path}")
+
 
 class LeannSearcher:
     def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
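
The method above fixes the embeddings-file contract: a pickled two-tuple of IDs and a float numpy matrix with one row per ID. A minimal sketch of producing a compatible file (shapes and file name are illustrative; 768 matches the contriever dimension the eval script assumes):

    import pickle

    import numpy as np

    ids = ["0", "1", "2"]  # one ID per row; coerced to str at build time anyway
    embeddings = np.random.rand(3, 768).astype(np.float32)  # (len(ids), dimensions)
    with open("passages_00.pkl", "wb") as f:
        pickle.dump((ids, embeddings), f)
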
@@ -425,7 +568,9 @@ class LeannSearcher:
         backend_name = self.meta_data["backend_name"]
         self.embedding_model = self.meta_data["embedding_model"]
         # Support both old and new format
-        self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
+        self.embedding_mode = self.meta_data.get(
+            "embedding_mode", "sentence-transformers"
+        )
         # Backward compatibility with use_mlx
         if self.meta_data.get("use_mlx", False):
             self.embedding_mode = "mlx"
@@ -457,6 +602,7 @@ class LeannSearcher:
         # Use backend's compute_query_embedding method
         # This will automatically use embedding server if available and needed
         import time
+
         start_time = time.time()
 
         query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
@@ -556,7 +702,7 @@ class LeannChat:
             "Please provide the best answer you can based on this context and your knowledge."
         )
 
-        ans=self.llm.ask(prompt, **llm_kwargs)
+        ans = self.llm.ask(prompt, **llm_kwargs)
         return ans
 
     def start_interactive(self):
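
Putting both halves together, a library-level build that mirrors the eval script's `--mode build` path might look like this (a sketch reusing the commit's own parameter choices; paths are illustrative, not code from the commit):

    from leann.api import LeannBuilder, LeannSearcher

    builder = LeannBuilder(
        backend_name="hnsw",
        embedding_model="facebook/contriever-msmarco",
        dimensions=768,
        M=32,
        efConstruction=256,
        is_compact=True,
        is_recompute=True,
    )
    builder.build_index_from_embeddings(
        "data/indices/dpr/dpr_from_embeddings", "passages_00.pkl"
    )
    searcher = LeannSearcher("data/indices/dpr/dpr_from_embeddings")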