Compare commits


13 Commits

Author SHA1 Message Date
aakash          cf67a848f4  Fix exclude pattern: use proper regex format for lychee  2025-11-13 12:10:28 -08:00
aakash          31ee48e3c8  Update exclude pattern and accept 503 status codes as fallback  2025-11-13 12:06:52 -08:00
aakash          3b23330bde  Use pattern-based exclude for star-history API  2025-11-13 12:06:36 -08:00
aakash          1ad9f75e96  Fix lychee exclude argument: use --exclude instead of --exclude-url  2025-11-13 12:02:11 -08:00
aakash          15d15f8881  Exclude star-history API URL directly in lychee args  2025-11-13 11:42:19 -08:00
aakash          d70e8fe40c  Revert --accept 503 flag, rely on .lycheeignore instead  2025-11-13 11:42:03 -08:00
aakash          e8c4ccde53  Configure lychee to accept 503 status codes for intermittently unavailable services  2025-11-13 11:28:01 -08:00
aakash          d27970538a  Fix .lycheeignore: ensure exactly one trailing newline  2025-11-13 11:26:54 -08:00
aakash          dab299043d  Fix .lycheeignore formatting and sync uv.lock from main  2025-11-13 11:26:47 -08:00
aakash          620da9dc27  Fix .lycheeignore formatting (add trailing newline)  2025-11-13 11:24:29 -08:00
aakash          27d5a49f94  Add .lycheeignore to exclude intermittently unavailable star-history API  2025-11-13 11:22:41 -08:00
aakash          043e32d959  Fix security vulnerability: Replace eval() with ast.literal_eval() in slack_mcp_reader.py  2025-11-13 11:18:07 -08:00
                            Fixes #163: Replace unsafe eval() calls with ast.literal_eval() to prevent code injection attacks. ast.literal_eval() safely evaluates only Python literals, preventing arbitrary code execution.
GitHub Actions  3c4785bb63  chore: release v0.3.5  2025-11-12 06:01:25 +00:00
12 changed files with 251 additions and 417 deletions


@@ -14,6 +14,6 @@ jobs:
       - uses: actions/checkout@v4
       - uses: lycheeverse/lychee-action@v2
         with:
-          args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
+          args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
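For a quick local sanity check of the new exclude pattern (the point of the "use proper regex format" commit), the regex can be exercised directly. A sketch only: lychee's regex engine is not Python's re, and the example URLs below are hypothetical.

# Rough approximation of the lychee --exclude pattern; illustrative only.
import re

pattern = re.compile(r".*api\.star-history\.com.*")

urls = [
    "https://api.star-history.com/svg?repos=example/repo",  # hypothetical URL, should be excluded
    "https://github.com/example/repo",                      # hypothetical URL, should still be checked
]
for url in urls:
    print(url, "-> excluded" if pattern.search(url) else "-> checked")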

.lycheeignore (new file, +4 lines)

@@ -0,0 +1,4 @@
# Exclude star-history API from link checking
# This service is intermittently unavailable (503 errors)
# but the link still works when the service is up
.*api\.star-history\.com.*


@@ -7,6 +7,7 @@ for indexing in LEANN. It supports various Slack MCP server implementations and
 flexible message processing options.
 """
+import ast
 import asyncio
 import json
 import logging
@@ -146,16 +147,16 @@ class SlackMCPReader:
         match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
         if match:
             try:
-                error_dict = eval(match.group(1))
-            except (ValueError, SyntaxError, NameError):
+                error_dict = ast.literal_eval(match.group(1))
+            except (ValueError, SyntaxError):
                 pass
         else:
             # Try alternative format
             match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
             if match:
                 try:
-                    error_dict = eval(match.group(1))
-                except (ValueError, SyntaxError, NameError):
+                    error_dict = ast.literal_eval(match.group(1))
+                except (ValueError, SyntaxError):
                     pass
         if self._is_cache_sync_error(error_dict):
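A minimal illustration of what the switch buys, which is the eval() risk flagged in #163: ast.literal_eval accepts only Python literals, so an injected expression is rejected instead of executed. The error strings below are invented for the example.

import ast

benign = "{'error': {'code': 'cache_sync', 'message': 'not ready'}}"
print(ast.literal_eval(benign))   # parses the literal dict safely

malicious = "__import__('os').system('echo pwned')"
try:
    ast.literal_eval(malicious)   # eval() would have executed this
except (ValueError, SyntaxError) as exc:
    print(f"rejected: {exc}")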


@@ -1,98 +0,0 @@
#!/usr/bin/env python3
"""
Test script to reproduce issue #159: Slow search performance

Configuration:
- GPU: A10
- embedding_model: BAAI/bge-large-zh-v1.5
- data size: 180M text (~90K chunks)
- backend: hnsw
"""

import os
import time
from pathlib import Path

from leann.api import LeannBuilder, LeannSearcher

os.environ["LEANN_LOG_LEVEL"] = "DEBUG"

# Configuration matching the issue
INDEX_PATH = "./test_issue_159.leann"
EMBEDDING_MODEL = "BAAI/bge-large-zh-v1.5"
BACKEND_NAME = "hnsw"


def generate_test_data(num_chunks=90000, chunk_size=2000):
    """Generate test data similar to 180MB text (~90K chunks)"""
    # Each chunk is approximately 2000 characters
    # 90K chunks * 2000 chars ≈ 180MB
    chunks = []
    base_text = (
        "这是一个测试文档。LEANN是一个创新的向量数据库, 通过图基选择性重计算实现97%的存储节省。"
    )
    for i in range(num_chunks):
        chunk = f"{base_text} 文档编号: {i}. " * (chunk_size // len(base_text) + 1)
        chunks.append(chunk[:chunk_size])
    return chunks


def test_search_performance():
    """Test search performance with different configurations"""
    print("=" * 80)
    print("Testing LEANN Search Performance (Issue #159)")
    print("=" * 80)

    meta_path = Path(f"{INDEX_PATH}.meta.json")
    if meta_path.exists():
        print(f"\n✓ Index already exists at {INDEX_PATH}")
        print(" Skipping build phase. Delete the index to rebuild.")
    else:
        print("\n📦 Building index...")
        print(f" Backend: {BACKEND_NAME}")
        print(f" Embedding Model: {EMBEDDING_MODEL}")
        print(" Generating test data (~90K chunks, ~180MB)...")
        chunks = generate_test_data(num_chunks=90000)
        print(f" Generated {len(chunks)} chunks")
        print(f" Total text size: {sum(len(c) for c in chunks) / (1024 * 1024):.2f} MB")

        builder = LeannBuilder(
            backend_name=BACKEND_NAME,
            embedding_model=EMBEDDING_MODEL,
        )
        print(" Adding chunks to builder...")
        start_time = time.time()
        for i, chunk in enumerate(chunks):
            builder.add_text(chunk)
            if (i + 1) % 10000 == 0:
                print(f" Added {i + 1}/{len(chunks)} chunks...")
        print(" Building index...")
        build_start = time.time()
        builder.build_index(INDEX_PATH)
        build_time = time.time() - build_start
        print(f" ✓ Index built in {build_time:.2f} seconds")

    # Test search with different complexity values
    print("\n🔍 Testing search performance...")
    searcher = LeannSearcher(INDEX_PATH)
    test_query = "LEANN向量数据库存储优化"

    # Test with minimal complexity (8)
    print("\n Test 4: Minimal complexity (8)")
    print(f" Query: '{test_query}'")
    start_time = time.time()
    results = searcher.search(test_query, top_k=10, complexity=8)
    search_time = time.time() - start_time
    print(f" ✓ Search completed in {search_time:.2f} seconds")
    print(f" Results: {len(results)} items")

    print("\n" + "=" * 80)


if __name__ == "__main__":
    test_search_performance()


@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
 [project]
 name = "leann-backend-diskann"
-version = "0.3.4"
-dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
+version = "0.3.5"
+dependencies = ["leann-core==0.3.5", "numpy", "protobuf>=3.19.0"]

 [tool.scikit-build]
 # Key: simplified CMake path


@@ -143,6 +143,8 @@ def create_hnsw_embedding_server(
             pass
         return str(nid)

+    # (legacy ZMQ thread removed; using shutdown-capable server only)
+
     def zmq_server_thread_with_shutdown(shutdown_event):
         """ZMQ server thread that respects shutdown signal.
@@ -156,238 +158,225 @@ def create_hnsw_embedding_server(
rep_socket.bind(f"tcp://*:{zmq_port}") rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}") logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
# Keep sends from blocking during shutdown; fail fast and drop on close
rep_socket.setsockopt(zmq.SNDTIMEO, 1000) rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0) rep_socket.setsockopt(zmq.LINGER, 0)
last_request_type = "unknown" # Track last request type/length for shape-correct fallbacks
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
last_request_length = 0 last_request_length = 0
def _build_safe_fallback():
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
return [[large_distance] * fallback_len]
if last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
if dim > 0:
return [[bsz, dim], [0.0] * (bsz * dim)]
return [[0, 0], []]
if last_request_type == "text":
return []
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
def _handle_text_embedding(request: list[str]) -> None:
nonlocal last_request_type, last_request_length
e2e_start = time.time()
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s")
def _handle_distance_request(request: list[Any]) -> None:
nonlocal last_request_type, last_request_length
e2e_start = time.time()
node_ids = request[0]
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as exc:
logger.error(f"Exception looking up passage ID {nid}: {exc}")
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else:
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as exc:
logger.error(f"Distance computation error, using sentinels: {exc}")
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
def _handle_embedding_by_id(request: Any) -> None:
nonlocal last_request_type, last_request_length
if isinstance(request, list) and len(request) == 1 and isinstance(request[0], list):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
e2e_start = time.time()
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as exc:
logger.error(f"Exception looking up passage ID {nid}: {exc}")
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
flat_data = []
else:
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as exc:
logger.error(f"Embedding computation error, returning zeros: {exc}")
response_payload = [dims, flat_data]
rep_socket.send(msgpack.packb(response_payload, use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Fallback Embed by Id E2E time: {e2e_end - e2e_start:.6f}s")
try: try:
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
e2e_start = time.time()
logger.debug("🔍 Waiting for ZMQ message...") logger.debug("🔍 Waiting for ZMQ message...")
request_bytes = rep_socket.recv() request_bytes = rep_socket.recv()
except zmq.Again:
continue
try: # Rest of the processing logic (same as original)
request = msgpack.unpackb(request_bytes) request = msgpack.unpackb(request_bytes)
except Exception as exc:
if shutdown_event.is_set():
logger.info("Shutdown in progress, ignoring ZMQ error")
break
logger.error(f"Error unpacking ZMQ message: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
continue
try: if len(request) == 1 and request[0] == "__QUERY_MODEL__":
# Model query response_bytes = msgpack.packb([model_name])
rep_socket.send(response_bytes)
continue
# Handle direct text embedding request
if ( if (
isinstance(request, list)
and len(request) == 1
and request[0] == "__QUERY_MODEL__"
):
rep_socket.send(msgpack.packb([model_name]))
# Direct text embedding
elif (
isinstance(request, list) isinstance(request, list)
and request and request
and all(isinstance(item, str) for item in request) and all(isinstance(item, str) for item in request)
): ):
_handle_text_embedding(request) last_request_type = "text"
# Distance calculation: [[ids], [query_vector]] last_request_length = len(request)
elif ( embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Handle distance calculation request: [[ids], [query_vector]]
if (
isinstance(request, list) isinstance(request, list)
and len(request) == 2 and len(request) == 2
and isinstance(request[0], list) and isinstance(request[0], list)
and isinstance(request[1], list) and isinstance(request[1], list)
): ):
_handle_distance_request(request) node_ids = request[0]
# Embedding-by-id fallback # Handle nested [[ids]] shape defensively
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
# Gather texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
# Prepare full-length response with large sentinel values
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as e:
logger.error(f"Distance computation error, using sentinels: {e}")
# Send response in expected shape [[distances]]
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Fallback: treat as embedding-by-id request
if (
isinstance(request, list)
and len(request) == 1
and isinstance(request[0], list)
):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
# Preallocate zero-filled flat data for robustness
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
# Collect texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
flat_data = []
else:
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as e:
logger.error(f"Embedding computation error, returning zeros: {e}")
response_payload = [dims, flat_data]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
rep_socket.send(response_bytes)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
# Shape-correct fallback
try:
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
safe = [[large_distance] * fallback_len]
elif last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
safe = (
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
)
elif last_request_type == "text":
safe = [] # direct text embeddings expectation is a flat list
else:
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
else: else:
_handle_embedding_by_id(request)
except Exception as exc:
if shutdown_event.is_set():
logger.info("Shutdown in progress, ignoring ZMQ error") logger.info("Shutdown in progress, ignoring ZMQ error")
break break
logger.error(f"Error in ZMQ server loop: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
finally: finally:
try: try:
rep_socket.close(0) rep_socket.close(0)
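For context, the request/response protocol visible in this hunk (msgpack over a ZMQ REQ/REP socket: a list of strings returns embeddings, ["__QUERY_MODEL__"] returns the model name, [[ids], [query_vector]] returns [[distances]]) can be exercised with a tiny client. A sketch only, assuming a server is already listening on the default port 5557 that appears elsewhere in this compare; the query text is just a sample reused from the removed test script.

# Hypothetical client sketch for the embedding server's msgpack-over-ZMQ protocol.
import msgpack
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)
sock.connect("tcp://localhost:5557")  # expected_zmq_port default seen in this compare

# Ask which embedding model the server is serving.
sock.send(msgpack.packb(["__QUERY_MODEL__"]))
print("model:", msgpack.unpackb(sock.recv()))

# Request embeddings for raw text (list of strings -> list of vectors).
sock.send(msgpack.packb(["LEANN向量数据库存储优化"]))
embeddings = msgpack.unpackb(sock.recv())
print("dim:", len(embeddings[0]))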


@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
 [project]
 name = "leann-backend-hnsw"
-version = "0.3.4"
+version = "0.3.5"
 description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
 dependencies = [
-    "leann-core==0.3.4",
+    "leann-core==0.3.5",
     "numpy",
     "pyzmq>=23.0.0",
     "msgpack>=1.0.0",


@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "leann-core"
-version = "0.3.4"
+version = "0.3.5"
 description = "Core API and plugin system for LEANN"
 readme = "README.md"
 requires-python = ">=3.9"


@@ -864,13 +864,7 @@ class LeannBuilder:
 class LeannSearcher:
-    def __init__(
-        self,
-        index_path: str,
-        enable_warmup: bool = True,
-        recompute_embeddings: bool = True,
-        **backend_kwargs,
-    ):
+    def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
         # Fix path resolution for Colab and other environments
         if not Path(index_path).is_absolute():
             index_path = str(Path(index_path).resolve())
@@ -901,32 +895,14 @@ class LeannSearcher:
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' not found.")

-        # Global recompute flag for this searcher (explicit knob, default True)
-        self.recompute_embeddings: bool = bool(recompute_embeddings)
-        # Warmup flag: keep using the existing enable_warmup parameter,
-        # but default it to True so cold-start happens earlier.
-        self._warmup: bool = bool(enable_warmup)
-
         final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
-        final_kwargs["enable_warmup"] = self._warmup
+        final_kwargs["enable_warmup"] = enable_warmup
         if self.embedding_options:
             final_kwargs.setdefault("embedding_options", self.embedding_options)
         self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
             index_path, **final_kwargs
         )
-        # Optional one-shot warmup at construction time to hide cold-start latency.
-        if self._warmup:
-            try:
-                _ = self.backend_impl.compute_query_embedding(
-                    "__LEANN_WARMUP__",
-                    use_server_if_available=self.recompute_embeddings,
-                )
-            except Exception as exc:
-                logger.warning(f"Warmup embedding failed (ignored): {exc}")

     def search(
         self,
         query: str,
@@ -934,7 +910,7 @@ class LeannSearcher:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: Optional[bool] = None,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         expected_zmq_port: int = 5557,
         metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
@@ -951,8 +927,7 @@ class LeannSearcher:
             complexity: Search complexity/candidate list size, higher = more accurate but slower
             beam_width: Number of parallel search paths/IO requests per iteration
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
-            recompute_embeddings: (Deprecated) Per-call override for recompute mode.
-                Configure this at LeannSearcher(..., recompute_embeddings=...) instead.
+            recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
             pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
             expected_zmq_port: ZMQ port for embedding server communication
             metadata_filters: Optional filters to apply to search results based on metadata.
@@ -991,19 +966,8 @@ class LeannSearcher:
         zmq_port = None

-        # Resolve effective recompute flag for this search.
-        if recompute_embeddings is not None:
-            logger.warning(
-                "LeannSearcher.search(..., recompute_embeddings=...) is deprecated and "
-                "will be removed in a future version. Configure recompute at "
-                "LeannSearcher(..., recompute_embeddings=...) instead."
-            )
-            effective_recompute = bool(recompute_embeddings)
-        else:
-            effective_recompute = self.recompute_embeddings
-
         start_time = time.time()
-        if effective_recompute:
+        if recompute_embeddings:
             zmq_port = self.backend_impl._ensure_server_running(
                 self.meta_path_str,
                 port=expected_zmq_port,
@@ -1017,7 +981,7 @@ class LeannSearcher:
             query_embedding = self.backend_impl.compute_query_embedding(
                 query,
-                use_server_if_available=effective_recompute,
+                use_server_if_available=recompute_embeddings,
                 zmq_port=zmq_port,
             )
             logger.info(f" Generated embedding shape: {query_embedding.shape}")
@@ -1029,7 +993,7 @@ class LeannSearcher:
             "complexity": complexity,
             "beam_width": beam_width,
             "prune_ratio": prune_ratio,
-            "recompute_embeddings": effective_recompute,
+            "recompute_embeddings": recompute_embeddings,
             "pruning_strategy": pruning_strategy,
             "zmq_port": zmq_port,
         }
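For contrast with the removed constructor-level knob, a usage sketch against the head-side signature shown above: warmup is opt-in at construction, and recompute_embeddings is a per-call flag on search(). The index path is hypothetical; the query is reused from the removed test script.

from leann.api import LeannSearcher

searcher = LeannSearcher("./my_index.leann", enable_warmup=False)  # hypothetical index path
results = searcher.search(
    "LEANN向量数据库存储优化",
    top_k=10,
    complexity=64,
    recompute_embeddings=True,  # fetch fresh embeddings from the server on this call
)
print(len(results))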


@@ -215,14 +215,9 @@ def compute_embeddings(
         Normalized embeddings array, shape: (len(texts), embedding_dim)
     """
     provider_options = provider_options or {}

-    wrapper_start_time = time.time()
-    logger.debug(
-        f"[compute_embeddings] entry: mode={mode}, model='{model_name}', text_count={len(texts)}"
-    )
-
     if mode == "sentence-transformers":
-        inner_start_time = time.time()
-        result = compute_embeddings_sentence_transformers(
+        return compute_embeddings_sentence_transformers(
             texts,
             model_name,
             is_build=is_build,
@@ -231,14 +226,6 @@ def compute_embeddings(
             manual_tokenize=manual_tokenize,
             max_length=max_length,
         )
-        inner_end_time = time.time()
-        wrapper_end_time = time.time()
-        logger.debug(
-            "[compute_embeddings] sentence-transformers timings: "
-            f"inner={inner_end_time - inner_start_time:.6f}s, "
-            f"wrapper_total={wrapper_end_time - wrapper_start_time:.6f}s"
-        )
-        return result
     elif mode == "openai":
         return compute_embeddings_openai(
             texts,
@@ -284,7 +271,6 @@ def compute_embeddings_sentence_transformers(
         is_build: Whether this is a build operation (shows progress bar)
         adaptive_optimization: Whether to use adaptive optimization based on batch size
     """
-    outer_start_time = time.time()
     # Handle empty input
     if not texts:
         raise ValueError("Cannot compute embeddings for empty text list")
@@ -315,14 +301,7 @@ def compute_embeddings_sentence_transformers(
     # Create cache key
     cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"

-    pre_model_init_end_time = time.time()
-    logger.debug(
-        "compute_embeddings_sentence_transformers pre-model-init time "
-        f"(device/batch selection etc.): {pre_model_init_end_time - outer_start_time:.6f}s"
-    )
-
     # Check if model is already cached
-    start_time = time.time()
     if cache_key in _model_cache:
         logger.info(f"Using cached optimized model: {model_name}")
         model = _model_cache[cache_key]
@@ -462,13 +441,10 @@ def compute_embeddings_sentence_transformers(
         _model_cache[cache_key] = model
         logger.info(f"Model cached: {cache_key}")

-    end_time = time.time()
-
     # Compute embeddings with optimized inference mode
     logger.info(
         f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
     )
-    logger.info(f"start sentence transformers {model} takes {end_time - start_time}")

     start_time = time.time()
     if not manual_tokenize:
@@ -489,46 +465,32 @@ def compute_embeddings_sentence_transformers(
             except Exception:
                 pass
         else:
-            # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel.
-            # This path is reserved for an aggressively optimized FP pipeline
-            # (no quantization), mainly for experimentation.
+            # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
             try:
                 from transformers import AutoModel, AutoTokenizer  # type: ignore
             except Exception as e:
                 raise ImportError(f"transformers is required for manual_tokenize=True: {e}")

-            # Cache tokenizer and model
             tok_cache_key = f"hf_tokenizer_{model_name}"
-            mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}_fp"
+            mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
             if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
                 hf_tokenizer = _model_cache[tok_cache_key]
                 hf_model = _model_cache[mdl_cache_key]
-                logger.info("Using cached HF tokenizer/model for manual FP path")
+                logger.info("Using cached HF tokenizer/model for manual path")
             else:
-                logger.info("Loading HF tokenizer/model for manual FP path")
+                logger.info("Loading HF tokenizer/model for manual tokenization path")
                 hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
                 torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
-                hf_model = AutoModel.from_pretrained(
-                    model_name,
-                    torch_dtype=torch_dtype,
-                )
+                hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
                 hf_model.to(device)
                 hf_model.eval()
                 # Optional compile on supported devices
                 if device in ["cuda", "mps"]:
                     try:
-                        hf_model = torch.compile(  # type: ignore
-                            hf_model, mode="reduce-overhead", dynamic=True
-                        )
-                        logger.info(
-                            f"Applied torch.compile to HF model for {model_name} "
-                            f"(device={device}, dtype={torch_dtype})"
-                        )
-                    except Exception as exc:
-                        logger.warning(f"torch.compile optimization failed: {exc}")
+                        hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True)  # type: ignore
+                    except Exception:
+                        pass

                 _model_cache[tok_cache_key] = hf_tokenizer
                 _model_cache[mdl_cache_key] = hf_model
@@ -554,6 +516,7 @@ def compute_embeddings_sentence_transformers(
             for start_index in batch_iter:
                 end_index = min(start_index + batch_size, len(texts))
                 batch_texts = texts[start_index:end_index]
+                tokenize_start_time = time.time()
                 inputs = hf_tokenizer(
                     batch_texts,
                     padding=True,
@@ -561,17 +524,34 @@ def compute_embeddings_sentence_transformers(
                     max_length=max_length,
                     return_tensors="pt",
                 )
+                tokenize_end_time = time.time()
+                logger.info(
+                    f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
+                )
+                # Print shapes of all input tensors for debugging
+                for k, v in inputs.items():
+                    print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
+                to_device_start_time = time.time()
                 inputs = {k: v.to(device) for k, v in inputs.items()}
+                to_device_end_time = time.time()
+                logger.info(
+                    f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
+                )
+                forward_start_time = time.time()
                 outputs = hf_model(**inputs)
+                forward_end_time = time.time()
+                logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
                 last_hidden_state = outputs.last_hidden_state  # (B, L, H)
                 attention_mask = inputs.get("attention_mask")
                 if attention_mask is None:
+                    # Fallback: assume all tokens are valid
                     pooled = last_hidden_state.mean(dim=1)
                 else:
                     mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
                     masked = last_hidden_state * mask
                     lengths = mask.sum(dim=1).clamp(min=1)
                     pooled = masked.sum(dim=1) / lengths
+                # Move to CPU float32
                 batch_embeddings = pooled.detach().to("cpu").float().numpy()
                 all_embeddings.append(batch_embeddings)
@@ -591,12 +571,6 @@ def compute_embeddings_sentence_transformers(
     if np.isnan(embeddings).any() or np.isinf(embeddings).any():
         raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")

-    outer_end_time = time.time()
-    logger.debug(
-        "compute_embeddings_sentence_transformers total time "
-        f"(function entry -> return): {outer_end_time - outer_start_time:.6f}s"
-    )
-
     return embeddings
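The pooling in the manual-tokenize path above is plain attention-mask mean pooling; a self-contained toy sketch of the same arithmetic (random tensors, no model or tokenizer):

import torch

last_hidden_state = torch.randn(2, 4, 8)                         # (B, L, H)
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])       # 1 = real token, 0 = padding

mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (B, L, 1)
masked = last_hidden_state * mask                                 # zero out padded positions
lengths = mask.sum(dim=1).clamp(min=1)                            # avoid division by zero
pooled = masked.sum(dim=1) / lengths                              # (B, H) mean over real tokens
print(pooled.shape)  # torch.Size([2, 8])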


@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "leann"
-version = "0.3.4"
+version = "0.3.5"
 description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
 readme = "README.md"
 requires-python = ">=3.9"