change wecaht app split logic& merge
This commit is contained in:
@@ -13,10 +13,10 @@ import sys
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
from leann.api import LeannSearcher
|
||||
from leann.api import LeannSearcher, LeannBuilder
|
||||
|
||||
|
||||
def download_data_if_needed(data_root: Path):
|
||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||
if not data_root.exists():
|
||||
print(f"Data directory '{data_root}' not found.")
|
||||
@@ -26,13 +26,32 @@ def download_data_if_needed(data_root: Path):
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir=data_root,
|
||||
local_dir_use_symlinks=False, # Recommended for Windows compatibility and simpler structure
|
||||
)
|
||||
print("Data download complete!")
|
||||
if download_embeddings:
|
||||
# Download everything including embeddings (large files)
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir=data_root,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
print("Data download complete (including embeddings)!")
|
||||
else:
|
||||
# Download only specific folders, excluding embeddings
|
||||
allow_patterns = [
|
||||
"ground_truth/**",
|
||||
"indices/**",
|
||||
"queries/**",
|
||||
"*.md",
|
||||
"*.txt",
|
||||
]
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir=data_root,
|
||||
local_dir_use_symlinks=False,
|
||||
allow_patterns=allow_patterns,
|
||||
)
|
||||
print("Data download complete (excluding embeddings)!")
|
||||
except ImportError:
|
||||
print(
|
||||
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||
@@ -44,6 +63,43 @@ def download_data_if_needed(data_root: Path):
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
||||
"""Download embeddings files specifically."""
|
||||
embeddings_dir = data_root / "embeddings"
|
||||
|
||||
if dataset_type:
|
||||
# Check if specific dataset embeddings exist
|
||||
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||
if target_file.exists():
|
||||
print(f"Embeddings for {dataset_type} already exist")
|
||||
return str(target_file)
|
||||
|
||||
print("Downloading embeddings from HuggingFace Hub...")
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download only embeddings folder
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir=data_root,
|
||||
local_dir_use_symlinks=False,
|
||||
allow_patterns=["embeddings/**/*.pkl"],
|
||||
)
|
||||
print("Embeddings download complete!")
|
||||
|
||||
if dataset_type:
|
||||
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||
if target_file.exists():
|
||||
return str(target_file)
|
||||
|
||||
return str(embeddings_dir)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error downloading embeddings: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# --- Helper Function to get Golden Passages ---
|
||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
||||
"""
|
||||
@@ -72,12 +128,76 @@ def load_queries(file_path: Path) -> List[str]:
|
||||
return queries
|
||||
|
||||
|
||||
def build_index_from_embeddings(
|
||||
embeddings_file: str, output_path: str, backend: str = "hnsw"
|
||||
):
|
||||
"""
|
||||
Build a LEANN index from pre-computed embeddings.
|
||||
|
||||
Args:
|
||||
embeddings_file: Path to pickle file with (ids, embeddings) tuple
|
||||
output_path: Path where to save the index
|
||||
backend: Backend to use ("hnsw" or "diskann")
|
||||
"""
|
||||
print(f"Building {backend} index from embeddings: {embeddings_file}")
|
||||
|
||||
# Create builder with appropriate parameters
|
||||
if backend == "hnsw":
|
||||
builder_kwargs = {
|
||||
"M": 32, # Graph degree
|
||||
"efConstruction": 256, # Construction complexity
|
||||
"is_compact": True, # Use compact storage
|
||||
"is_recompute": True, # Enable pruning for better recall
|
||||
}
|
||||
elif backend == "diskann":
|
||||
builder_kwargs = {
|
||||
"complexity": 64,
|
||||
"graph_degree": 32,
|
||||
"search_memory_maximum": 8.0, # GB
|
||||
"build_memory_maximum": 16.0, # GB
|
||||
}
|
||||
else:
|
||||
builder_kwargs = {}
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
|
||||
dimensions=768, # Will be auto-detected from embeddings
|
||||
**builder_kwargs,
|
||||
)
|
||||
|
||||
# Build index from precomputed embeddings
|
||||
builder.build_index_from_embeddings(output_path, embeddings_file)
|
||||
print(f"Index saved to: {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run recall evaluation on a LEANN index."
|
||||
)
|
||||
parser.add_argument(
|
||||
"index_path", type=str, help="Path to the LEANN index to evaluate."
|
||||
"index_path",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="Path to the LEANN index to evaluate or build (optional).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["evaluate", "build"],
|
||||
default="evaluate",
|
||||
help="Mode: 'evaluate' existing index or 'build' from embeddings",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embeddings-file",
|
||||
type=str,
|
||||
help="Path to embeddings pickle file (optional for build mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["hnsw", "diskann"],
|
||||
default="hnsw",
|
||||
help="Backend to use for building index (default: hnsw)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||
@@ -96,8 +216,90 @@ def main():
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
data_root = project_root / "data"
|
||||
|
||||
# Automatically download data if it doesn't exist
|
||||
download_data_if_needed(data_root)
|
||||
# Download data based on mode
|
||||
if args.mode == "build":
|
||||
# For building mode, we need embeddings
|
||||
download_data_if_needed(
|
||||
data_root, download_embeddings=False
|
||||
) # Basic data first
|
||||
|
||||
# Auto-detect dataset type and download embeddings
|
||||
if args.embeddings_file:
|
||||
embeddings_file = args.embeddings_file
|
||||
# Try to detect dataset type from embeddings file path
|
||||
if "rpj_wiki" in str(embeddings_file):
|
||||
dataset_type = "rpj_wiki"
|
||||
elif "dpr" in str(embeddings_file):
|
||||
dataset_type = "dpr"
|
||||
else:
|
||||
dataset_type = "dpr" # Default
|
||||
else:
|
||||
# Auto-detect from index path if provided, otherwise default to DPR
|
||||
if args.index_path:
|
||||
index_path_str = str(args.index_path)
|
||||
if "rpj_wiki" in index_path_str:
|
||||
dataset_type = "rpj_wiki"
|
||||
elif "dpr" in index_path_str:
|
||||
dataset_type = "dpr"
|
||||
else:
|
||||
dataset_type = "dpr" # Default to DPR
|
||||
else:
|
||||
dataset_type = "dpr" # Default to DPR
|
||||
|
||||
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
|
||||
|
||||
# Auto-generate index path if not provided
|
||||
if not args.index_path:
|
||||
indices_dir = data_root / "indices" / dataset_type
|
||||
indices_dir.mkdir(parents=True, exist_ok=True)
|
||||
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
|
||||
print(f"Auto-generated index path: {args.index_path}")
|
||||
|
||||
print(f"Building index from embeddings: {embeddings_file}")
|
||||
built_index_path = build_index_from_embeddings(
|
||||
embeddings_file, args.index_path, args.backend
|
||||
)
|
||||
print(f"Index built successfully: {built_index_path}")
|
||||
|
||||
# Ask if user wants to run evaluation
|
||||
eval_response = (
|
||||
input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||
)
|
||||
if eval_response != "y":
|
||||
print("Index building complete. Exiting.")
|
||||
return
|
||||
else:
|
||||
# For evaluation mode, don't need embeddings
|
||||
download_data_if_needed(data_root, download_embeddings=False)
|
||||
|
||||
# Auto-detect index path if not provided
|
||||
if not args.index_path:
|
||||
# Default to using downloaded indices
|
||||
indices_dir = data_root / "indices"
|
||||
|
||||
# Try common datasets in order of preference
|
||||
for dataset in ["dpr", "rpj_wiki"]:
|
||||
dataset_dir = indices_dir / dataset
|
||||
if dataset_dir.exists():
|
||||
# Look for index files
|
||||
index_files = list(dataset_dir.glob("*.index")) + list(
|
||||
dataset_dir.glob("*_disk.index")
|
||||
)
|
||||
if index_files:
|
||||
args.index_path = str(
|
||||
index_files[0].with_suffix("")
|
||||
) # Remove .index extension
|
||||
print(f"Using index: {args.index_path}")
|
||||
break
|
||||
|
||||
if not args.index_path:
|
||||
print(
|
||||
"No indices found. The data download should have included pre-built indices."
|
||||
)
|
||||
print(
|
||||
"Please check the data/indices/ directory or provide --index-path manually."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Detect dataset type from index path to select the correct ground truth
|
||||
index_path_str = str(args.index_path)
|
||||
|
||||
@@ -18,15 +18,15 @@ from .chat import get_llm
|
||||
|
||||
|
||||
def compute_embeddings(
|
||||
chunks: List[str],
|
||||
model_name: str,
|
||||
chunks: List[str],
|
||||
model_name: str,
|
||||
mode: str = "sentence-transformers",
|
||||
use_server: bool = True,
|
||||
use_mlx: bool = False # Backward compatibility: if True, override mode to 'mlx'
|
||||
use_mlx: bool = False # Backward compatibility: if True, override mode to 'mlx',
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes embeddings using different backends.
|
||||
|
||||
|
||||
Args:
|
||||
chunks: List of text chunks to embed
|
||||
model_name: Name of the embedding model
|
||||
@@ -35,7 +35,7 @@ def compute_embeddings(
|
||||
- "mlx": Use MLX backend for Apple Silicon
|
||||
- "openai": Use OpenAI embedding API
|
||||
use_server: Whether to use embedding server (True for search, False for build)
|
||||
|
||||
|
||||
Returns:
|
||||
numpy array of embeddings
|
||||
"""
|
||||
@@ -46,33 +46,41 @@ def compute_embeddings(
|
||||
# Auto-detect mode based on model name if not explicitly set
|
||||
if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
|
||||
mode = "openai"
|
||||
|
||||
|
||||
if mode == "mlx":
|
||||
return compute_embeddings_mlx(chunks, model_name, batch_size=16)
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(chunks, model_name)
|
||||
elif mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(chunks, model_name, use_server=use_server)
|
||||
return compute_embeddings_sentence_transformers(
|
||||
chunks, model_name, use_server=use_server
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai")
|
||||
raise ValueError(
|
||||
f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
|
||||
)
|
||||
|
||||
|
||||
def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str, use_server: bool = True) -> np.ndarray:
|
||||
def compute_embeddings_sentence_transformers(
|
||||
chunks: List[str], model_name: str, use_server: bool = True
|
||||
) -> np.ndarray:
|
||||
"""Computes embeddings using sentence-transformers.
|
||||
|
||||
|
||||
Args:
|
||||
chunks: List of text chunks to embed
|
||||
model_name: Name of the sentence transformer model
|
||||
use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
|
||||
"""
|
||||
if not use_server:
|
||||
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)...")
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
|
||||
)
|
||||
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
||||
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||
)
|
||||
|
||||
|
||||
# Use embedding server for sentence-transformers too
|
||||
# This avoids loading the model twice (once in API, once in server)
|
||||
try:
|
||||
@@ -81,49 +89,55 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str,
|
||||
import msgpack
|
||||
import numpy as np
|
||||
from .embedding_server_manager import EmbeddingServerManager
|
||||
|
||||
|
||||
# Ensure embedding server is running
|
||||
port = 5557
|
||||
server_manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw.hnsw_embedding_server")
|
||||
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
|
||||
server_started = server_manager.start_server(
|
||||
port=port,
|
||||
model_name=model_name,
|
||||
embedding_mode="sentence-transformers",
|
||||
enable_warmup=False,
|
||||
)
|
||||
|
||||
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
||||
|
||||
|
||||
# Connect to embedding server
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
|
||||
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
|
||||
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Fallback to direct sentence-transformers if server connection fails
|
||||
print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}")
|
||||
print(
|
||||
f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
|
||||
)
|
||||
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
||||
|
||||
|
||||
def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
def _compute_embeddings_sentence_transformers_direct(
|
||||
chunks: List[str], model_name: str
|
||||
) -> np.ndarray:
|
||||
"""Direct sentence-transformers computation (fallback)."""
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
@@ -164,16 +178,18 @@ def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
raise RuntimeError(
|
||||
"openai not available. Install with: uv pip install openai"
|
||||
) from e
|
||||
|
||||
|
||||
# Get API key from environment
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
|
||||
print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...")
|
||||
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
|
||||
)
|
||||
|
||||
# OpenAI has a limit on batch size and input length
|
||||
max_batch_size = 100 # Conservative batch size
|
||||
all_embeddings = []
|
||||
@@ -191,18 +207,17 @@ def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
batch_chunks = chunks[i:i + max_batch_size]
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(
|
||||
model=model_name,
|
||||
input=batch_chunks
|
||||
)
|
||||
response = client.embeddings.create(model=model_name, input=batch_chunks)
|
||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||
print(f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}")
|
||||
print(
|
||||
f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
|
||||
)
|
||||
return embeddings
|
||||
|
||||
|
||||
@@ -345,7 +360,12 @@ class LeannBuilder:
|
||||
raise ValueError("No chunks added.")
|
||||
if self.dimensions is None:
|
||||
self.dimensions = len(
|
||||
compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode, use_server=False)[0]
|
||||
compute_embeddings(
|
||||
["dummy"],
|
||||
self.embedding_model,
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
)[0]
|
||||
)
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
@@ -414,6 +434,129 @@ class LeannBuilder:
|
||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
|
||||
def build_index_from_embeddings(self, index_path: str, embeddings_file: str):
|
||||
"""
|
||||
Build an index from pre-computed embeddings stored in a pickle file.
|
||||
|
||||
Args:
|
||||
index_path: Path where the index will be saved
|
||||
embeddings_file: Path to pickle file containing (ids, embeddings) tuple
|
||||
"""
|
||||
# Load pre-computed embeddings
|
||||
with open(embeddings_file, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
if not isinstance(data, tuple) or len(data) != 2:
|
||||
raise ValueError(
|
||||
f"Invalid embeddings file format. Expected tuple with 2 elements, got {type(data)}"
|
||||
)
|
||||
|
||||
ids, embeddings = data
|
||||
|
||||
if not isinstance(embeddings, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Expected embeddings to be numpy array, got {type(embeddings)}"
|
||||
)
|
||||
|
||||
if len(ids) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"Mismatch between number of IDs ({len(ids)}) and embeddings ({embeddings.shape[0]})"
|
||||
)
|
||||
|
||||
# Validate/set dimensions
|
||||
embedding_dim = embeddings.shape[1]
|
||||
if self.dimensions is None:
|
||||
self.dimensions = embedding_dim
|
||||
elif self.dimensions != embedding_dim:
|
||||
raise ValueError(
|
||||
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
|
||||
)
|
||||
|
||||
# Ensure we have text data for each embedding
|
||||
if len(self.chunks) != len(ids):
|
||||
# If no text chunks provided, create placeholder text entries
|
||||
if not self.chunks:
|
||||
print("No text chunks provided, creating placeholder entries...")
|
||||
for id_val in ids:
|
||||
self.add_text(
|
||||
f"Document {id_val}",
|
||||
metadata={"id": str(id_val), "from_embeddings": True},
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Number of text chunks ({len(self.chunks)}) doesn't match number of embeddings ({len(ids)})"
|
||||
)
|
||||
|
||||
# Build file structure
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
index_name = path.name
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
passages_file = index_dir / f"{index_name}.passages.jsonl"
|
||||
offset_file = index_dir / f"{index_name}.passages.idx"
|
||||
|
||||
# Write passages and create offset map
|
||||
offset_map = {}
|
||||
with open(passages_file, "w", encoding="utf-8") as f:
|
||||
for chunk in self.chunks:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
{
|
||||
"id": chunk["id"],
|
||||
"text": chunk["text"],
|
||||
"metadata": chunk["metadata"],
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
f.write("\n")
|
||||
offset_map[chunk["id"]] = offset
|
||||
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
# Build the vector index using precomputed embeddings
|
||||
string_ids = [str(id_val) for id_val in ids]
|
||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||
builder_instance.build(embeddings, string_ids, index_path)
|
||||
|
||||
# Create metadata file
|
||||
leann_meta_path = index_dir / f"{index_name}.meta.json"
|
||||
meta_data = {
|
||||
"version": "1.0",
|
||||
"backend_name": self.backend_name,
|
||||
"embedding_model": self.embedding_model,
|
||||
"dimensions": self.dimensions,
|
||||
"backend_kwargs": self.backend_kwargs,
|
||||
"embedding_mode": self.embedding_mode,
|
||||
"passage_sources": [
|
||||
{
|
||||
"type": "jsonl",
|
||||
"path": str(passages_file),
|
||||
"index_path": str(offset_file),
|
||||
}
|
||||
],
|
||||
"built_from_precomputed_embeddings": True,
|
||||
"embeddings_source": str(embeddings_file),
|
||||
}
|
||||
|
||||
# Add storage status flags for HNSW backend
|
||||
if self.backend_name == "hnsw":
|
||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||
meta_data["is_compact"] = is_compact
|
||||
meta_data["is_pruned"] = is_compact and is_recompute
|
||||
|
||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
|
||||
print(f"Index built successfully from precomputed embeddings: {index_path}")
|
||||
|
||||
|
||||
class LeannSearcher:
|
||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||
@@ -425,7 +568,9 @@ class LeannSearcher:
|
||||
backend_name = self.meta_data["backend_name"]
|
||||
self.embedding_model = self.meta_data["embedding_model"]
|
||||
# Support both old and new format
|
||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||
self.embedding_mode = self.meta_data.get(
|
||||
"embedding_mode", "sentence-transformers"
|
||||
)
|
||||
# Backward compatibility with use_mlx
|
||||
if self.meta_data.get("use_mlx", False):
|
||||
self.embedding_mode = "mlx"
|
||||
@@ -457,6 +602,7 @@ class LeannSearcher:
|
||||
# Use backend's compute_query_embedding method
|
||||
# This will automatically use embedding server if available and needed
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
|
||||
@@ -556,7 +702,7 @@ class LeannChat:
|
||||
"Please provide the best answer you can based on this context and your knowledge."
|
||||
)
|
||||
|
||||
ans=self.llm.ask(prompt, **llm_kwargs)
|
||||
ans = self.llm.ask(prompt, **llm_kwargs)
|
||||
return ans
|
||||
|
||||
def start_interactive(self):
|
||||
|
||||
Reference in New Issue
Block a user