Compare commits

..

11 Commits

Author SHA1 Message Date
Andy Lee
ed15776564 push 2025-11-24 08:07:51 +00:00
Andy Lee
8d202b8b0e chore 2025-11-24 08:05:51 +00:00
Andy Lee
9ac9eab48d fix 2025-11-24 08:05:29 +00:00
Andy Lee
cd1d853a46 fix 2025-11-24 08:01:42 +00:00
Andy Lee
253680043a fix: recompute args in searcher 2025-11-24 07:58:20 +00:00
Andy Lee
36c44b8806 fix: faster embed 2025-11-24 05:30:11 +00:00
Andy Lee
66c6aad3e4 refactor: embedding server 2025-11-19 06:54:10 +00:00
Andy Lee
29ef3c95dc refactor: embedding server 2025-11-19 06:50:39 +00:00
yichuan-w
469dce0045 add test 2025-11-12 08:46:05 +00:00
CalebZ9909
0ac676f9cb Add reproduction test script for Issue #159
- Test script to reproduce slow search performance issue
- Generates ~90K chunks (~180MB) similar to user's dataset
- Tests search performance with different complexity values (8, 16, 32, 64)
- Demonstrates that complexity=16-32 achieves ~2s search time
- Validates the performance analysis findings
2025-11-12 08:08:34 +00:00
CalebZ9909
97c9f39704 Add performance analysis and reproduction script for Issue #159
- Reproduced the slow search performance issue (15-30s vs expected ~2s)
- Identified root cause: default complexity=64 is too high for fast search
- Created test script demonstrating performance with different complexity values
- Test results show complexity=16-32 achieves ~2s search time (matching paper)
- Added comprehensive analysis document with solutions and recommendations

Key findings:
- Default complexity=64 results in ~36s search time
- Reducing complexity to 16-32 achieves ~2s search time
- beam_width parameter is mainly for DiskANN, not HNSW
- Paper likely used smaller embedding model (~100M) and lower complexity

Solutions provided:
1. Reduce complexity parameter to 16-32 for faster search
2. Consider DiskANN backend for better performance on large datasets
3. Use smaller embedding model if speed is critical
2025-11-12 08:03:48 +00:00
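The tuning described in the two commits above maps directly onto the search API changed later in this compare. A minimal sketch of sweeping the complexity parameter (the index path and query string are placeholders, not taken from the diff):

import time
from leann.api import LeannSearcher

searcher = LeannSearcher("./my_index.leann")
for complexity in (8, 16, 32, 64):
    start = time.time()
    results = searcher.search("LEANN storage savings", top_k=10, complexity=complexity)
    print(f"complexity={complexity}: {time.time() - start:.2f}s, {len(results)} results")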
12 changed files with 417 additions and 251 deletions

View File

@@ -14,6 +14,6 @@ jobs:
       - uses: actions/checkout@v4
       - uses: lycheeverse/lychee-action@v2
         with:
-          args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
+          args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -1,4 +0,0 @@
-# Exclude star-history API from link checking
-# This service is intermittently unavailable (503 errors)
-# but the link still works when the service is up
-.*api\.star-history\.com.*

View File

@@ -7,7 +7,6 @@ for indexing in LEANN. It supports various Slack MCP server implementations and
 flexible message processing options.
 """
-import ast
 import asyncio
 import json
 import logging
@@ -147,16 +146,16 @@ class SlackMCPReader:
             match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
             if match:
                 try:
-                    error_dict = ast.literal_eval(match.group(1))
-                except (ValueError, SyntaxError):
+                    error_dict = eval(match.group(1))
+                except (ValueError, SyntaxError, NameError):
                     pass
             else:
                 # Try alternative format
                 match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
                 if match:
                     try:
-                        error_dict = ast.literal_eval(match.group(1))
-                    except (ValueError, SyntaxError):
+                        error_dict = eval(match.group(1))
+                    except (ValueError, SyntaxError, NameError):
                         pass
             if self._is_cache_sync_error(error_dict):
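An aside on the change in this hunk: ast.literal_eval only accepts Python literals, while eval executes arbitrary expressions found in the error string. A small illustrative sketch of the safer parse (the error string and its contents here are made up):

import ast
import re

err = "Failed to fetch messages: {'ok': False, 'error': 'channel_not_found'}"
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", err)
if match:
    try:
        error_dict = ast.literal_eval(match.group(1))  # -> {'ok': False, 'error': 'channel_not_found'}
    except (ValueError, SyntaxError):
        error_dict = {}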

benchmarks/issue_159.py (new file, 98 lines)
View File

@@ -0,0 +1,98 @@
#!/usr/bin/env python3
"""
Test script to reproduce issue #159: Slow search performance
Configuration:
- GPU: A10
- embedding_model: BAAI/bge-large-zh-v1.5
- data size: 180M text (~90K chunks)
- backend: hnsw
"""
import os
import time
from pathlib import Path
from leann.api import LeannBuilder, LeannSearcher
os.environ["LEANN_LOG_LEVEL"] = "DEBUG"
# Configuration matching the issue
INDEX_PATH = "./test_issue_159.leann"
EMBEDDING_MODEL = "BAAI/bge-large-zh-v1.5"
BACKEND_NAME = "hnsw"
def generate_test_data(num_chunks=90000, chunk_size=2000):
"""Generate test data similar to 180MB text (~90K chunks)"""
# Each chunk is approximately 2000 characters
# 90K chunks * 2000 chars ≈ 180MB
chunks = []
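    # Translation of the Chinese sample text below: "This is a test document. LEANN is an
    # innovative vector database that achieves 97% storage savings through graph-based
    # selective recomputation."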
base_text = (
"这是一个测试文档。LEANN是一个创新的向量数据库, 通过图基选择性重计算实现97%的存储节省。"
)
for i in range(num_chunks):
chunk = f"{base_text} 文档编号: {i}. " * (chunk_size // len(base_text) + 1)
chunks.append(chunk[:chunk_size])
return chunks
def test_search_performance():
"""Test search performance with different configurations"""
print("=" * 80)
print("Testing LEANN Search Performance (Issue #159)")
print("=" * 80)
meta_path = Path(f"{INDEX_PATH}.meta.json")
if meta_path.exists():
print(f"\n✓ Index already exists at {INDEX_PATH}")
print(" Skipping build phase. Delete the index to rebuild.")
else:
print("\n📦 Building index...")
print(f" Backend: {BACKEND_NAME}")
print(f" Embedding Model: {EMBEDDING_MODEL}")
print(" Generating test data (~90K chunks, ~180MB)...")
chunks = generate_test_data(num_chunks=90000)
print(f" Generated {len(chunks)} chunks")
print(f" Total text size: {sum(len(c) for c in chunks) / (1024 * 1024):.2f} MB")
builder = LeannBuilder(
backend_name=BACKEND_NAME,
embedding_model=EMBEDDING_MODEL,
)
print(" Adding chunks to builder...")
start_time = time.time()
for i, chunk in enumerate(chunks):
builder.add_text(chunk)
if (i + 1) % 10000 == 0:
print(f" Added {i + 1}/{len(chunks)} chunks...")
print(" Building index...")
build_start = time.time()
builder.build_index(INDEX_PATH)
build_time = time.time() - build_start
print(f" ✓ Index built in {build_time:.2f} seconds")
# Test search with different complexity values
print("\n🔍 Testing search performance...")
searcher = LeannSearcher(INDEX_PATH)
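    # Translation of the query below: "LEANN vector database storage optimization"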
test_query = "LEANN向量数据库存储优化"
# Test with minimal complexity (8)
print("\n Test 4: Minimal complexity (8)")
print(f" Query: '{test_query}'")
start_time = time.time()
results = searcher.search(test_query, top_k=10, complexity=8)
search_time = time.time() - start_time
print(f" ✓ Search completed in {search_time:.2f} seconds")
print(f" Results: {len(results)} items")
print("\n" + "=" * 80)
if __name__ == "__main__":
test_search_performance()

View File

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
 [project]
 name = "leann-backend-diskann"
-version = "0.3.5"
-dependencies = ["leann-core==0.3.5", "numpy", "protobuf>=3.19.0"]
+version = "0.3.4"
+dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
 [tool.scikit-build]
 # Key: simplified CMake path
View File

@@ -143,8 +143,6 @@ def create_hnsw_embedding_server(
pass pass
return str(nid) return str(nid)
# (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event): def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal. """ZMQ server thread that respects shutdown signal.
@@ -158,225 +156,238 @@ def create_hnsw_embedding_server(
rep_socket.bind(f"tcp://*:{zmq_port}") rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}") logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
# Keep sends from blocking during shutdown; fail fast and drop on close
rep_socket.setsockopt(zmq.SNDTIMEO, 1000) rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0) rep_socket.setsockopt(zmq.LINGER, 0)
# Track last request type/length for shape-correct fallbacks last_request_type = "unknown"
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
last_request_length = 0 last_request_length = 0
def _build_safe_fallback():
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
return [[large_distance] * fallback_len]
if last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
if dim > 0:
return [[bsz, dim], [0.0] * (bsz * dim)]
return [[0, 0], []]
if last_request_type == "text":
return []
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
def _handle_text_embedding(request: list[str]) -> None:
nonlocal last_request_type, last_request_length
e2e_start = time.time()
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s")
def _handle_distance_request(request: list[Any]) -> None:
nonlocal last_request_type, last_request_length
e2e_start = time.time()
node_ids = request[0]
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as exc:
logger.error(f"Exception looking up passage ID {nid}: {exc}")
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else:
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as exc:
logger.error(f"Distance computation error, using sentinels: {exc}")
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
def _handle_embedding_by_id(request: Any) -> None:
nonlocal last_request_type, last_request_length
if isinstance(request, list) and len(request) == 1 and isinstance(request[0], list):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
e2e_start = time.time()
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as exc:
logger.error(f"Exception looking up passage ID {nid}: {exc}")
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
flat_data = []
else:
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as exc:
logger.error(f"Embedding computation error, returning zeros: {exc}")
response_payload = [dims, flat_data]
rep_socket.send(msgpack.packb(response_payload, use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Fallback Embed by Id E2E time: {e2e_end - e2e_start:.6f}s")
try: try:
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
e2e_start = time.time()
logger.debug("🔍 Waiting for ZMQ message...") logger.debug("🔍 Waiting for ZMQ message...")
request_bytes = rep_socket.recv() request_bytes = rep_socket.recv()
except zmq.Again:
continue
# Rest of the processing logic (same as original) try:
request = msgpack.unpackb(request_bytes) request = msgpack.unpackb(request_bytes)
except Exception as exc:
if shutdown_event.is_set():
logger.info("Shutdown in progress, ignoring ZMQ error")
break
logger.error(f"Error unpacking ZMQ message: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
continue
if len(request) == 1 and request[0] == "__QUERY_MODEL__": try:
response_bytes = msgpack.packb([model_name]) # Model query
rep_socket.send(response_bytes)
continue
# Handle direct text embedding request
if ( if (
isinstance(request, list)
and len(request) == 1
and request[0] == "__QUERY_MODEL__"
):
rep_socket.send(msgpack.packb([model_name]))
# Direct text embedding
elif (
isinstance(request, list) isinstance(request, list)
and request and request
and all(isinstance(item, str) for item in request) and all(isinstance(item, str) for item in request)
): ):
last_request_type = "text" _handle_text_embedding(request)
last_request_length = len(request) # Distance calculation: [[ids], [query_vector]]
embeddings = compute_embeddings( elif (
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Handle distance calculation request: [[ids], [query_vector]]
if (
isinstance(request, list) isinstance(request, list)
and len(request) == 2 and len(request) == 2
and isinstance(request[0], list) and isinstance(request[0], list)
and isinstance(request[1], list) and isinstance(request[1], list)
): ):
node_ids = request[0] _handle_distance_request(request)
# Handle nested [[ids]] shape defensively # Embedding-by-id fallback
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
# Gather texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
# Prepare full-length response with large sentinel values
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as e:
logger.error(f"Distance computation error, using sentinels: {e}")
# Send response in expected shape [[distances]]
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Fallback: treat as embedding-by-id request
if (
isinstance(request, list)
and len(request) == 1
and isinstance(request[0], list)
):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
# Preallocate zero-filled flat data for robustness
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
# Collect texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
flat_data = []
else:
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as e:
logger.error(f"Embedding computation error, returning zeros: {e}")
response_payload = [dims, flat_data]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
rep_socket.send(response_bytes)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
# Shape-correct fallback
try:
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
safe = [[large_distance] * fallback_len]
elif last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
safe = (
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
)
elif last_request_type == "text":
safe = [] # direct text embeddings expectation is a flat list
else:
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
else: else:
_handle_embedding_by_id(request)
except Exception as exc:
if shutdown_event.is_set():
logger.info("Shutdown in progress, ignoring ZMQ error") logger.info("Shutdown in progress, ignoring ZMQ error")
break break
logger.error(f"Error in ZMQ server loop: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
finally: finally:
try: try:
rep_socket.close(0) rep_socket.close(0)
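For orientation, the refactored server loop in this hunk follows a shutdown-aware ZMQ REP pattern with msgpack payloads. A minimal, self-contained sketch of that pattern (the port number, handler, and payloads below are illustrative only, not taken from the file):

import threading
import msgpack
import zmq

def rep_loop(shutdown_event: threading.Event, port: int = 5557) -> None:
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.REP)
    sock.bind(f"tcp://*:{port}")
    sock.setsockopt(zmq.RCVTIMEO, 1000)  # wake up regularly to re-check the shutdown flag
    sock.setsockopt(zmq.SNDTIMEO, 1000)  # never block forever on send
    sock.setsockopt(zmq.LINGER, 0)       # drop pending messages on close
    try:
        while not shutdown_event.is_set():
            try:
                request = msgpack.unpackb(sock.recv())
            except zmq.Again:
                continue  # receive timeout: loop back and check shutdown_event again
            # Always reply once per request so the REP socket stays in sync.
            sock.send(msgpack.packb(handle(request)))
    finally:
        sock.close(0)

def handle(request):
    # Placeholder dispatch; the real server routes to text-embedding, distance,
    # and embedding-by-id handlers with shape-correct fallbacks.
    return []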

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
 [project]
 name = "leann-backend-hnsw"
-version = "0.3.5"
+version = "0.3.4"
 description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
 dependencies = [
-    "leann-core==0.3.5",
+    "leann-core==0.3.4",
     "numpy",
     "pyzmq>=23.0.0",
     "msgpack>=1.0.0",

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "leann-core"
-version = "0.3.5"
+version = "0.3.4"
 description = "Core API and plugin system for LEANN"
 readme = "README.md"
 requires-python = ">=3.9"

View File

@@ -864,7 +864,13 @@ class LeannBuilder:
 class LeannSearcher:
-    def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
+    def __init__(
+        self,
+        index_path: str,
+        enable_warmup: bool = True,
+        recompute_embeddings: bool = True,
+        **backend_kwargs,
+    ):
         # Fix path resolution for Colab and other environments
         if not Path(index_path).is_absolute():
             index_path = str(Path(index_path).resolve())
@@ -895,14 +901,32 @@ class LeannSearcher:
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' not found.")
+        # Global recompute flag for this searcher (explicit knob, default True)
+        self.recompute_embeddings: bool = bool(recompute_embeddings)
+        # Warmup flag: keep using the existing enable_warmup parameter,
+        # but default it to True so cold-start happens earlier.
+        self._warmup: bool = bool(enable_warmup)
         final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
-        final_kwargs["enable_warmup"] = enable_warmup
+        final_kwargs["enable_warmup"] = self._warmup
         if self.embedding_options:
             final_kwargs.setdefault("embedding_options", self.embedding_options)
         self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
             index_path, **final_kwargs
         )
+        # Optional one-shot warmup at construction time to hide cold-start latency.
+        if self._warmup:
+            try:
+                _ = self.backend_impl.compute_query_embedding(
+                    "__LEANN_WARMUP__",
+                    use_server_if_available=self.recompute_embeddings,
+                )
+            except Exception as exc:
+                logger.warning(f"Warmup embedding failed (ignored): {exc}")
     def search(
         self,
         query: str,
@@ -910,7 +934,7 @@ class LeannSearcher:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = True,
+        recompute_embeddings: Optional[bool] = None,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         expected_zmq_port: int = 5557,
         metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
@@ -927,7 +951,8 @@ class LeannSearcher:
             complexity: Search complexity/candidate list size, higher = more accurate but slower
             beam_width: Number of parallel search paths/IO requests per iteration
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
-            recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
+            recompute_embeddings: (Deprecated) Per-call override for recompute mode.
+                Configure this at LeannSearcher(..., recompute_embeddings=...) instead.
             pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
             expected_zmq_port: ZMQ port for embedding server communication
             metadata_filters: Optional filters to apply to search results based on metadata.
@@ -966,8 +991,19 @@ class LeannSearcher:
         zmq_port = None
+        # Resolve effective recompute flag for this search.
+        if recompute_embeddings is not None:
+            logger.warning(
+                "LeannSearcher.search(..., recompute_embeddings=...) is deprecated and "
+                "will be removed in a future version. Configure recompute at "
+                "LeannSearcher(..., recompute_embeddings=...) instead."
+            )
+            effective_recompute = bool(recompute_embeddings)
+        else:
+            effective_recompute = self.recompute_embeddings
         start_time = time.time()
-        if recompute_embeddings:
+        if effective_recompute:
             zmq_port = self.backend_impl._ensure_server_running(
                 self.meta_path_str,
                 port=expected_zmq_port,
@@ -981,7 +1017,7 @@ class LeannSearcher:
             query_embedding = self.backend_impl.compute_query_embedding(
                 query,
-                use_server_if_available=recompute_embeddings,
+                use_server_if_available=effective_recompute,
                 zmq_port=zmq_port,
             )
             logger.info(f" Generated embedding shape: {query_embedding.shape}")
@@ -993,7 +1029,7 @@ class LeannSearcher:
"complexity": complexity, "complexity": complexity,
"beam_width": beam_width, "beam_width": beam_width,
"prune_ratio": prune_ratio, "prune_ratio": prune_ratio,
"recompute_embeddings": recompute_embeddings, "recompute_embeddings": effective_recompute,
"pruning_strategy": pruning_strategy, "pruning_strategy": pruning_strategy,
"zmq_port": zmq_port, "zmq_port": zmq_port,
} }
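A short usage sketch of the knobs these hunks introduce, assuming an existing index (the path and queries are placeholders): recompute is now configured once on the searcher, and the per-call flag only acts as a deprecated override.

from leann.api import LeannSearcher

# Warmup now defaults to True; recompute can be switched off for the whole searcher.
searcher = LeannSearcher("./my_index.leann", enable_warmup=True, recompute_embeddings=False)
results = searcher.search("graph-based recomputation", top_k=5, complexity=32)

# Passing recompute_embeddings per call still works, but logs a deprecation warning.
results = searcher.search("graph-based recomputation", top_k=5, recompute_embeddings=True)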

View File

@@ -215,9 +215,14 @@ def compute_embeddings(
         Normalized embeddings array, shape: (len(texts), embedding_dim)
     """
     provider_options = provider_options or {}
+    wrapper_start_time = time.time()
+    logger.debug(
+        f"[compute_embeddings] entry: mode={mode}, model='{model_name}', text_count={len(texts)}"
+    )
     if mode == "sentence-transformers":
-        return compute_embeddings_sentence_transformers(
+        inner_start_time = time.time()
+        result = compute_embeddings_sentence_transformers(
             texts,
             model_name,
             is_build=is_build,
@@ -226,6 +231,14 @@ def compute_embeddings(
             manual_tokenize=manual_tokenize,
             max_length=max_length,
         )
+        inner_end_time = time.time()
+        wrapper_end_time = time.time()
+        logger.debug(
+            "[compute_embeddings] sentence-transformers timings: "
+            f"inner={inner_end_time - inner_start_time:.6f}s, "
+            f"wrapper_total={wrapper_end_time - wrapper_start_time:.6f}s"
+        )
+        return result
     elif mode == "openai":
         return compute_embeddings_openai(
             texts,
@@ -271,6 +284,7 @@ def compute_embeddings_sentence_transformers(
         is_build: Whether this is a build operation (shows progress bar)
         adaptive_optimization: Whether to use adaptive optimization based on batch size
     """
+    outer_start_time = time.time()
     # Handle empty input
     if not texts:
         raise ValueError("Cannot compute embeddings for empty text list")
@@ -301,7 +315,14 @@ def compute_embeddings_sentence_transformers(
     # Create cache key
     cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
+    pre_model_init_end_time = time.time()
+    logger.debug(
+        "compute_embeddings_sentence_transformers pre-model-init time "
+        f"(device/batch selection etc.): {pre_model_init_end_time - outer_start_time:.6f}s"
+    )
     # Check if model is already cached
+    start_time = time.time()
     if cache_key in _model_cache:
         logger.info(f"Using cached optimized model: {model_name}")
         model = _model_cache[cache_key]
@@ -441,10 +462,13 @@ def compute_embeddings_sentence_transformers(
         _model_cache[cache_key] = model
         logger.info(f"Model cached: {cache_key}")
+    end_time = time.time()
     # Compute embeddings with optimized inference mode
     logger.info(
         f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
     )
+    logger.info(f"start sentence transformers {model} takes {end_time - start_time}")
     start_time = time.time()
     if not manual_tokenize:
@@ -465,32 +489,46 @@ def compute_embeddings_sentence_transformers(
         except Exception:
             pass
     else:
-        # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
+        # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel.
+        # This path is reserved for an aggressively optimized FP pipeline
+        # (no quantization), mainly for experimentation.
         try:
             from transformers import AutoModel, AutoTokenizer  # type: ignore
         except Exception as e:
             raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
-        # Cache tokenizer and model
         tok_cache_key = f"hf_tokenizer_{model_name}"
-        mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
+        mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}_fp"
         if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
             hf_tokenizer = _model_cache[tok_cache_key]
             hf_model = _model_cache[mdl_cache_key]
-            logger.info("Using cached HF tokenizer/model for manual path")
+            logger.info("Using cached HF tokenizer/model for manual FP path")
         else:
-            logger.info("Loading HF tokenizer/model for manual tokenization path")
+            logger.info("Loading HF tokenizer/model for manual FP path")
             hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
             torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
-            hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
+            hf_model = AutoModel.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+            )
             hf_model.to(device)
             hf_model.eval()
             # Optional compile on supported devices
             if device in ["cuda", "mps"]:
                 try:
-                    hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True)  # type: ignore
-                except Exception:
-                    pass
+                    hf_model = torch.compile(  # type: ignore
+                        hf_model, mode="reduce-overhead", dynamic=True
+                    )
+                    logger.info(
+                        f"Applied torch.compile to HF model for {model_name} "
+                        f"(device={device}, dtype={torch_dtype})"
+                    )
+                except Exception as exc:
+                    logger.warning(f"torch.compile optimization failed: {exc}")
             _model_cache[tok_cache_key] = hf_tokenizer
             _model_cache[mdl_cache_key] = hf_model
@@ -516,7 +554,6 @@ def compute_embeddings_sentence_transformers(
         for start_index in batch_iter:
             end_index = min(start_index + batch_size, len(texts))
             batch_texts = texts[start_index:end_index]
-            tokenize_start_time = time.time()
             inputs = hf_tokenizer(
                 batch_texts,
                 padding=True,
@@ -524,34 +561,17 @@ def compute_embeddings_sentence_transformers(
                 max_length=max_length,
                 return_tensors="pt",
             )
-            tokenize_end_time = time.time()
-            logger.info(
-                f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
-            )
-            # Print shapes of all input tensors for debugging
-            for k, v in inputs.items():
-                print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
-            to_device_start_time = time.time()
             inputs = {k: v.to(device) for k, v in inputs.items()}
-            to_device_end_time = time.time()
-            logger.info(
-                f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
-            )
-            forward_start_time = time.time()
             outputs = hf_model(**inputs)
-            forward_end_time = time.time()
-            logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
             last_hidden_state = outputs.last_hidden_state  # (B, L, H)
             attention_mask = inputs.get("attention_mask")
             if attention_mask is None:
-                # Fallback: assume all tokens are valid
                 pooled = last_hidden_state.mean(dim=1)
             else:
                 mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
                 masked = last_hidden_state * mask
                 lengths = mask.sum(dim=1).clamp(min=1)
                 pooled = masked.sum(dim=1) / lengths
-            # Move to CPU float32
             batch_embeddings = pooled.detach().to("cpu").float().numpy()
             all_embeddings.append(batch_embeddings)
@@ -571,6 +591,12 @@ def compute_embeddings_sentence_transformers(
     if np.isnan(embeddings).any() or np.isinf(embeddings).any():
         raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
+    outer_end_time = time.time()
+    logger.debug(
+        "compute_embeddings_sentence_transformers total time "
+        f"(function entry -> return): {outer_end_time - outer_start_time:.6f}s"
+    )
     return embeddings
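The manual-tokenization path in this file mean-pools the last hidden state under the attention mask. A standalone sketch of just that pooling step (shapes assumed: last_hidden_state (B, L, H), attention_mask (B, L)):

import torch

def masked_mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (B, L, 1)
    masked = last_hidden_state * mask          # zero out padded positions
    lengths = mask.sum(dim=1).clamp(min=1)     # number of real tokens per sequence
    return masked.sum(dim=1) / lengths         # (B, H) mean over real tokens only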

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "leann"
-version = "0.3.5"
+version = "0.3.4"
 description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
 readme = "README.md"
 requires-python = ">=3.9"