Compare commits


1 Commit

Author SHA1 Message Date
GitHub Actions
3c4785bb63 chore: release v0.3.5 2025-11-12 06:01:25 +00:00
10 changed files with 241 additions and 573 deletions

View File

@@ -1,110 +0,0 @@
# Issue #159 Performance Analysis - Conclusion
## Problem Summary
User reported search times of 15-30 seconds instead of the ~2 seconds mentioned in the paper.
**Configuration:**
- GPU: 4090×1
- Embedding Model: BAAI/bge-large-zh-v1.5 (~300M parameters)
- Data Size: 180MB text (~90K chunks)
- Backend: HNSW
- beam_width: 10
- Other parameters: Default values
## Root Cause Analysis
### 1. **Search Complexity Parameter**
The **default `complexity` parameter is 64**, which is too high for achieving ~2 second search times with this configuration.
**Test Results (Reproduced):**
- **Complexity 64 (default)**: **36.17 seconds**
- **Complexity 32**: **2.49 seconds**
- **Complexity 16**: **2.24 seconds** ✅ (Close to paper's ~2 seconds)
- **Complexity 8**: **1.67 seconds**
### 2. **beam_width Parameter**
The `beam_width` parameter is **mainly for DiskANN backend**, not HNSW. Setting it to 10 has minimal/no effect on HNSW search performance.
### 3. **Embedding Model Size**
The paper uses a smaller embedding model (~100M parameters), while the user is using `BAAI/bge-large-zh-v1.5` (~300M parameters). This contributes to slower embedding computation during search, but the main bottleneck is the search complexity parameter.
## Solution
### **Recommended Fix: Reduce Search Complexity**
To achieve search times close to ~2 seconds, use:
```python
from leann.api import LeannSearcher
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search(
    query="your query",
    top_k=10,
    complexity=16,  # or complexity=32 for slightly better accuracy
    # beam_width parameter doesn't affect HNSW, can be ignored
)
```
Or via CLI:
```bash
leann search your-index "your query" --complexity 16
```
### **Alternative Solutions**
1. **Use DiskANN Backend** (Recommended by maintainer)
- DiskANN is faster for large datasets
- Better performance scaling
- `beam_width` parameter is relevant here
```python
builder = LeannBuilder(backend_name="diskann")
```
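If you do switch backends, `beam_width` becomes meaningful at search time. A minimal search-side sketch, assuming the same `LeannSearcher` API as above (path and parameter values are illustrative, not tuned):
```python
from leann.api import LeannSearcher

# Assumes an index previously built with backend_name="diskann"
searcher = LeannSearcher("./my-diskann-index.leann")  # hypothetical path
results = searcher.search(
    query="your query",
    top_k=10,
    complexity=16,  # same speed/accuracy trade-off as with HNSW
    beam_width=10,  # number of parallel IO requests; only relevant for DiskANN
)
```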
2. **Use Smaller Embedding Model**
- Switch to a smaller model (~100M parameters) like the paper
- Faster embedding computation
- Example: `BAAI/bge-base-zh-v1.5` instead of `bge-large-zh-v1.5`
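A build-side sketch with the smaller model, assuming the same `LeannBuilder` API shown above (the model swap is the only change):
```python
from leann.api import LeannBuilder

builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="BAAI/bge-base-zh-v1.5",  # ~100M-parameter model, closer to the paper's setup
)
```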
3. **Disable Recomputation** (Trade storage for speed)
- Use `--no-recompute` flag
- Stores all embeddings (much larger storage)
- Faster search (no embedding recomputation)
```bash
leann build your-index --no-recompute --no-compact
leann search your-index "query" --no-recompute
```
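For the Python API, the search-side knob is the `recompute_embeddings` flag on `LeannSearcher.search()` (visible in the `leann-core` diff further below). A sketch, assuming the index was built with `--no-recompute` so stored embeddings are available:
```python
# Skip embedding recomputation at query time; only useful when the
# index stores full embeddings (i.e. it was built with --no-recompute).
results = searcher.search(
    query="your query",
    top_k=10,
    complexity=16,
    recompute_embeddings=False,
)
```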
## Performance Comparison
| Complexity | Search Time | Accuracy | Recommendation |
|------------|-------------|----------|---------------|
| 64 (default) | ~36s | Highest | ❌ Too slow |
| 32 | ~2.5s | High | ✅ Good balance |
| 16 | ~2.2s | Good | ✅ **Recommended** (matches paper) |
| 8 | ~1.7s | Lower | ⚠️ May sacrifice accuracy |
## Key Takeaways
1. **The default `complexity=64` is optimized for accuracy, not speed**
2. **For ~2 second search times, use `complexity=16` or `complexity=32`**
3. **`beam_width` parameter is for DiskANN, not HNSW**
4. **The paper's ~2 second results likely used:**
- Smaller embedding model (~100M params)
- Lower complexity (16-32)
- Possibly DiskANN backend
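Putting these takeaways together, a minimal end-to-end sketch of a paper-like configuration (model choice, chunk source, and parameter values are illustrative assumptions, not a verified reproduction):
```python
from leann.api import LeannBuilder, LeannSearcher

INDEX_PATH = "./paper_like_index.leann"  # hypothetical path

# Build with a smaller embedding model and the DiskANN backend
builder = LeannBuilder(
    backend_name="diskann",
    embedding_model="BAAI/bge-base-zh-v1.5",
)
for chunk in my_chunks:  # my_chunks: your own list of text chunks
    builder.add_text(chunk)
builder.build_index(INDEX_PATH)

# Search with a lower complexity for ~2 second latency
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search("your query", top_k=10, complexity=16, beam_width=10)
```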
## Verification
The issue has been reproduced and verified. The test script `test_issue_159.py` demonstrates:
- Default complexity (64) results in ~36 second search times
- Reducing complexity to 16-32 achieves ~2 second search times
- This matches the user's reported issue and provides a clear solution
## Next Steps
1. ✅ Issue reproduced and root cause identified
2. ✅ Solution provided (reduce complexity parameter)
3. ⏳ User should test with `complexity=16` or `complexity=32`
4. ⏳ Consider updating documentation to clarify complexity parameter trade-offs

View File

@@ -1,149 +0,0 @@
#!/usr/bin/env python3
"""
Test script to reproduce issue #159: Slow search performance
Configuration:
- GPU: 4090×1
- embedding_model: BAAI/bge-large-zh-v1.5
- data size: 180MB text (~90K chunks)
- beam_width: 10 (though this is mainly for DiskANN, not HNSW)
- backend: hnsw
"""

import os
import time
from pathlib import Path

from leann.api import LeannBuilder, LeannSearcher, SearchResult

os.environ["LEANN_LOG_LEVEL"] = "DEBUG"

# Configuration matching the issue
INDEX_PATH = "./test_issue_159.leann"
EMBEDDING_MODEL = "BAAI/bge-large-zh-v1.5"
BACKEND_NAME = "hnsw"


def generate_test_data(num_chunks=90000, chunk_size=2000):
    """Generate test data similar to 180MB text (~90K chunks)"""
    # Each chunk is approximately 2000 characters
    # 90K chunks * 2000 chars ≈ 180MB
    chunks = []
    # Chinese test text (the issue uses a Chinese embedding model); roughly:
    # "This is a test document. LEANN is an innovative vector database that
    # achieves 97% storage savings via graph-based selective recomputation."
    base_text = (
        "这是一个测试文档。LEANN是一个创新的向量数据库通过图基选择性重计算实现97%的存储节省。"
    )
    for i in range(num_chunks):
        chunk = f"{base_text} 文档编号: {i}. " * (chunk_size // len(base_text) + 1)
        chunks.append(chunk[:chunk_size])
    return chunks


def test_search_performance():
    """Test search performance with different configurations"""
    print("=" * 80)
    print("Testing LEANN Search Performance (Issue #159)")
    print("=" * 80)

    meta_path = Path(f"{INDEX_PATH}.meta.json")
    if meta_path.exists():
        print(f"\n✓ Index already exists at {INDEX_PATH}")
        print(" Skipping build phase. Delete the index to rebuild.")
    else:
        print("\n📦 Building index...")
        print(f" Backend: {BACKEND_NAME}")
        print(f" Embedding Model: {EMBEDDING_MODEL}")
        print(" Generating test data (~90K chunks, ~180MB)...")
        chunks = generate_test_data(num_chunks=90000)
        print(f" Generated {len(chunks)} chunks")
        print(f" Total text size: {sum(len(c) for c in chunks) / (1024 * 1024):.2f} MB")

        builder = LeannBuilder(
            backend_name=BACKEND_NAME,
            embedding_model=EMBEDDING_MODEL,
        )
        print(" Adding chunks to builder...")
        start_time = time.time()
        for i, chunk in enumerate(chunks):
            builder.add_text(chunk)
            if (i + 1) % 10000 == 0:
                print(f" Added {i + 1}/{len(chunks)} chunks...")

        print(" Building index...")
        build_start = time.time()
        builder.build_index(INDEX_PATH)
        build_time = time.time() - build_start
        print(f" ✓ Index built in {build_time:.2f} seconds")

    # Test search with different complexity values
    print("\n🔍 Testing search performance...")
    searcher = LeannSearcher(INDEX_PATH)
    test_query = "LEANN向量数据库存储优化"  # "LEANN vector database storage optimization"

    # Warm-up run with default complexity (64); the first search absorbs cold-start costs
    print("\n Warm-up: Default complexity (64)")
    print(f" Query: '{test_query}'")
    start_time = time.time()
    results: list[SearchResult] = searcher.search(test_query, top_k=10, complexity=64)
    search_time = time.time() - start_time
    print(f" ✓ Search completed in {search_time:.2f} seconds")
    print(f" Results: {len(results)} items")

    # Test with default complexity (64)
    print("\n Test 1: Default complexity (64)")
    print(f" Query: '{test_query}'")
    start_time = time.time()
    results = searcher.search(test_query, top_k=10, complexity=64)
    search_time = time.time() - start_time
    print(f" ✓ Search completed in {search_time:.2f} seconds")
    print(f" Results: {len(results)} items")

    # Test with lower complexity (32)
    print("\n Test 2: Lower complexity (32)")
    print(f" Query: '{test_query}'")
    start_time = time.time()
    results = searcher.search(test_query, top_k=10, complexity=32)
    search_time = time.time() - start_time
    print(f" ✓ Search completed in {search_time:.2f} seconds")
    print(f" Results: {len(results)} items")

    # Test with even lower complexity (16)
    print("\n Test 3: Lower complexity (16)")
    print(f" Query: '{test_query}'")
    start_time = time.time()
    results = searcher.search(test_query, top_k=10, complexity=16)
    search_time = time.time() - start_time
    print(f" ✓ Search completed in {search_time:.2f} seconds")
    print(f" Results: {len(results)} items")

    # Test with minimal complexity (8)
    print("\n Test 4: Minimal complexity (8)")
    print(f" Query: '{test_query}'")
    start_time = time.time()
    results = searcher.search(test_query, top_k=10, complexity=8)
    search_time = time.time() - start_time
    print(f" ✓ Search completed in {search_time:.2f} seconds")
    print(f" Results: {len(results)} items")

    print("\n" + "=" * 80)
    print("Performance Analysis:")
    print("=" * 80)
    print("\nKey Findings:")
    print("1. beam_width parameter is mainly for DiskANN backend, not HNSW")
    print("2. For HNSW, the main parameter affecting search speed is 'complexity'")
    print("3. Lower complexity values (16-32) should provide faster search")
    print("4. The paper mentions ~2 seconds, which likely uses:")
    print(" - Smaller embedding model (~100M params vs 300M for bge-large)")
    print(" - Lower complexity (16-32)")
    print(" - Possibly DiskANN backend for better performance")
    print("\nRecommendations:")
    print("- Try complexity=16 or complexity=32 for faster search")
    print("- Consider using DiskANN backend for better performance on large datasets")
    print("- Or use a smaller embedding model if speed is critical")


if __name__ == "__main__":
    test_search_performance()

View File

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
 [project]
 name = "leann-backend-diskann"
-version = "0.3.4"
-dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
+version = "0.3.5"
+dependencies = ["leann-core==0.3.5", "numpy", "protobuf>=3.19.0"]

 [tool.scikit-build]
 # Key: simplified CMake path
View File

@@ -143,6 +143,8 @@ def create_hnsw_embedding_server(
pass pass
return str(nid) return str(nid)
# (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event): def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal. """ZMQ server thread that respects shutdown signal.
@@ -156,238 +158,225 @@ def create_hnsw_embedding_server(
rep_socket.bind(f"tcp://*:{zmq_port}") rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}") logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
# Keep sends from blocking during shutdown; fail fast and drop on close
rep_socket.setsockopt(zmq.SNDTIMEO, 1000) rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0) rep_socket.setsockopt(zmq.LINGER, 0)
last_request_type = "unknown" # Track last request type/length for shape-correct fallbacks
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
last_request_length = 0 last_request_length = 0
def _build_safe_fallback():
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
return [[large_distance] * fallback_len]
if last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
if dim > 0:
return [[bsz, dim], [0.0] * (bsz * dim)]
return [[0, 0], []]
if last_request_type == "text":
return []
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
def _handle_text_embedding(request: list[str]) -> None:
nonlocal last_request_type, last_request_length
e2e_start = time.time()
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s")
def _handle_distance_request(request: list[Any]) -> None:
nonlocal last_request_type, last_request_length
e2e_start = time.time()
node_ids = request[0]
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as exc:
logger.error(f"Exception looking up passage ID {nid}: {exc}")
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else:
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as exc:
logger.error(f"Distance computation error, using sentinels: {exc}")
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
def _handle_embedding_by_id(request: Any) -> None:
nonlocal last_request_type, last_request_length
if isinstance(request, list) and len(request) == 1 and isinstance(request[0], list):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
e2e_start = time.time()
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as exc:
logger.error(f"Exception looking up passage ID {nid}: {exc}")
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
flat_data = []
else:
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as exc:
logger.error(f"Embedding computation error, returning zeros: {exc}")
response_payload = [dims, flat_data]
rep_socket.send(msgpack.packb(response_payload, use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Fallback Embed by Id E2E time: {e2e_end - e2e_start:.6f}s")
try: try:
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
e2e_start = time.time()
logger.debug("🔍 Waiting for ZMQ message...") logger.debug("🔍 Waiting for ZMQ message...")
request_bytes = rep_socket.recv() request_bytes = rep_socket.recv()
except zmq.Again:
continue
try: # Rest of the processing logic (same as original)
request = msgpack.unpackb(request_bytes) request = msgpack.unpackb(request_bytes)
except Exception as exc:
if shutdown_event.is_set():
logger.info("Shutdown in progress, ignoring ZMQ error")
break
logger.error(f"Error unpacking ZMQ message: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
continue
try: if len(request) == 1 and request[0] == "__QUERY_MODEL__":
# Model query response_bytes = msgpack.packb([model_name])
rep_socket.send(response_bytes)
continue
# Handle direct text embedding request
if ( if (
isinstance(request, list)
and len(request) == 1
and request[0] == "__QUERY_MODEL__"
):
rep_socket.send(msgpack.packb([model_name]))
# Direct text embedding
elif (
isinstance(request, list) isinstance(request, list)
and request and request
and all(isinstance(item, str) for item in request) and all(isinstance(item, str) for item in request)
): ):
_handle_text_embedding(request) last_request_type = "text"
# Distance calculation: [[ids], [query_vector]] last_request_length = len(request)
elif ( embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Handle distance calculation request: [[ids], [query_vector]]
if (
isinstance(request, list) isinstance(request, list)
and len(request) == 2 and len(request) == 2
and isinstance(request[0], list) and isinstance(request[0], list)
and isinstance(request[1], list) and isinstance(request[1], list)
): ):
_handle_distance_request(request) node_ids = request[0]
# Embedding-by-id fallback # Handle nested [[ids]] shape defensively
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
# Gather texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
# Prepare full-length response with large sentinel values
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as e:
logger.error(f"Distance computation error, using sentinels: {e}")
# Send response in expected shape [[distances]]
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Fallback: treat as embedding-by-id request
if (
isinstance(request, list)
and len(request) == 1
and isinstance(request[0], list)
):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
# Preallocate zero-filled flat data for robustness
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
# Collect texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
flat_data = []
else:
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as e:
logger.error(f"Embedding computation error, returning zeros: {e}")
response_payload = [dims, flat_data]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
rep_socket.send(response_bytes)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
# Shape-correct fallback
try:
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
safe = [[large_distance] * fallback_len]
elif last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
safe = (
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
)
elif last_request_type == "text":
safe = [] # direct text embeddings expectation is a flat list
else:
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
else: else:
_handle_embedding_by_id(request)
except Exception as exc:
if shutdown_event.is_set():
logger.info("Shutdown in progress, ignoring ZMQ error") logger.info("Shutdown in progress, ignoring ZMQ error")
break break
logger.error(f"Error in ZMQ server loop: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
finally: finally:
try: try:
rep_socket.close(0) rep_socket.close(0)

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
 [project]
 name = "leann-backend-hnsw"
-version = "0.3.4"
+version = "0.3.5"
 description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
 dependencies = [
-    "leann-core==0.3.4",
+    "leann-core==0.3.5",
     "numpy",
     "pyzmq>=23.0.0",
     "msgpack>=1.0.0",

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "leann-core"
-version = "0.3.4"
+version = "0.3.5"
 description = "Core API and plugin system for LEANN"
 readme = "README.md"
 requires-python = ">=3.9"

View File

@@ -864,13 +864,7 @@ class LeannBuilder:
 class LeannSearcher:
-    def __init__(
-        self,
-        index_path: str,
-        enable_warmup: bool = True,
-        recompute_embeddings: bool = True,
-        **backend_kwargs,
-    ):
+    def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
         # Fix path resolution for Colab and other environments
         if not Path(index_path).is_absolute():
             index_path = str(Path(index_path).resolve())
@@ -901,32 +895,14 @@ class LeannSearcher:
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' not found.")
-        # Global recompute flag for this searcher (explicit knob, default True)
-        self.recompute_embeddings: bool = bool(recompute_embeddings)
-        # Warmup flag: keep using the existing enable_warmup parameter,
-        # but default it to True so cold-start happens earlier.
-        self._warmup: bool = bool(enable_warmup)
         final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
-        final_kwargs["enable_warmup"] = self._warmup
+        final_kwargs["enable_warmup"] = enable_warmup
         if self.embedding_options:
             final_kwargs.setdefault("embedding_options", self.embedding_options)
         self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
             index_path, **final_kwargs
         )
-        # Optional one-shot warmup at construction time to hide cold-start latency.
-        if self._warmup:
-            try:
-                _ = self.backend_impl.compute_query_embedding(
-                    "__LEANN_WARMUP__",
-                    use_server_if_available=self.recompute_embeddings,
-                )
-            except Exception as exc:
-                logger.warning(f"Warmup embedding failed (ignored): {exc}")

     def search(
         self,
         query: str,
@@ -934,7 +910,7 @@ class LeannSearcher:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: Optional[bool] = None,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         expected_zmq_port: int = 5557,
         metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
@@ -951,8 +927,7 @@ class LeannSearcher:
             complexity: Search complexity/candidate list size, higher = more accurate but slower
             beam_width: Number of parallel search paths/IO requests per iteration
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
-            recompute_embeddings: (Deprecated) Per-call override for recompute mode.
-                Configure this at LeannSearcher(..., recompute_embeddings=...) instead.
+            recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
             pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
             expected_zmq_port: ZMQ port for embedding server communication
             metadata_filters: Optional filters to apply to search results based on metadata.
@@ -991,19 +966,8 @@ class LeannSearcher:
         zmq_port = None
-        # Resolve effective recompute flag for this search.
-        if recompute_embeddings is not None:
-            logger.warning(
-                "LeannSearcher.search(..., recompute_embeddings=...) is deprecated and "
-                "will be removed in a future version. Configure recompute at "
-                "LeannSearcher(..., recompute_embeddings=...) instead."
-            )
-            effective_recompute = bool(recompute_embeddings)
-        else:
-            effective_recompute = self.recompute_embeddings

         start_time = time.time()
-        if effective_recompute:
+        if recompute_embeddings:
             zmq_port = self.backend_impl._ensure_server_running(
                 self.meta_path_str,
                 port=expected_zmq_port,
@@ -1017,7 +981,7 @@ class LeannSearcher:
             query_embedding = self.backend_impl.compute_query_embedding(
                 query,
-                use_server_if_available=effective_recompute,
+                use_server_if_available=recompute_embeddings,
                 zmq_port=zmq_port,
             )
             logger.info(f" Generated embedding shape: {query_embedding.shape}")
@@ -1029,7 +993,7 @@ class LeannSearcher:
"complexity": complexity, "complexity": complexity,
"beam_width": beam_width, "beam_width": beam_width,
"prune_ratio": prune_ratio, "prune_ratio": prune_ratio,
"recompute_embeddings": effective_recompute, "recompute_embeddings": recompute_embeddings,
"pruning_strategy": pruning_strategy, "pruning_strategy": pruning_strategy,
"zmq_port": zmq_port, "zmq_port": zmq_port,
} }

View File

@@ -215,14 +215,9 @@ def compute_embeddings(
         Normalized embeddings array, shape: (len(texts), embedding_dim)
     """
     provider_options = provider_options or {}
-    wrapper_start_time = time.time()
-    logger.debug(
-        f"[compute_embeddings] entry: mode={mode}, model='{model_name}', text_count={len(texts)}"
-    )
     if mode == "sentence-transformers":
-        inner_start_time = time.time()
-        result = compute_embeddings_sentence_transformers(
+        return compute_embeddings_sentence_transformers(
             texts,
             model_name,
             is_build=is_build,
@@ -231,14 +226,6 @@ def compute_embeddings(
             manual_tokenize=manual_tokenize,
             max_length=max_length,
         )
-        inner_end_time = time.time()
-        wrapper_end_time = time.time()
-        logger.debug(
-            "[compute_embeddings] sentence-transformers timings: "
-            f"inner={inner_end_time - inner_start_time:.6f}s, "
-            f"wrapper_total={wrapper_end_time - wrapper_start_time:.6f}s"
-        )
-        return result
     elif mode == "openai":
         return compute_embeddings_openai(
             texts,
@@ -284,7 +271,6 @@ def compute_embeddings_sentence_transformers(
         is_build: Whether this is a build operation (shows progress bar)
         adaptive_optimization: Whether to use adaptive optimization based on batch size
     """
-    outer_start_time = time.time()
     # Handle empty input
     if not texts:
         raise ValueError("Cannot compute embeddings for empty text list")
@@ -315,14 +301,7 @@ def compute_embeddings_sentence_transformers(
     # Create cache key
     cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
-    pre_model_init_end_time = time.time()
-    logger.debug(
-        "compute_embeddings_sentence_transformers pre-model-init time "
-        f"(device/batch selection etc.): {pre_model_init_end_time - outer_start_time:.6f}s"
-    )
     # Check if model is already cached
-    start_time = time.time()
     if cache_key in _model_cache:
         logger.info(f"Using cached optimized model: {model_name}")
         model = _model_cache[cache_key]
@@ -462,13 +441,10 @@ def compute_embeddings_sentence_transformers(
         _model_cache[cache_key] = model
         logger.info(f"Model cached: {cache_key}")
-    end_time = time.time()
-
-    # Compute embeddings with optimized inference mode
-    logger.info(
-        f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
-    )
-    logger.info(f"start sentence transformers {model} takes {end_time - start_time}")
+    # Compute embeddings with optimized inference mode
+    logger.info(
+        f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
+    )

     start_time = time.time()
     if not manual_tokenize:
@@ -489,46 +465,32 @@ def compute_embeddings_sentence_transformers(
        except Exception:
            pass
    else:
-        # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel.
-        # This path is reserved for an aggressively optimized FP pipeline
-        # (no quantization), mainly for experimentation.
+        # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
         try:
             from transformers import AutoModel, AutoTokenizer  # type: ignore
         except Exception as e:
             raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
+        # Cache tokenizer and model
         tok_cache_key = f"hf_tokenizer_{model_name}"
-        mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}_fp"
+        mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
         if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
             hf_tokenizer = _model_cache[tok_cache_key]
             hf_model = _model_cache[mdl_cache_key]
-            logger.info("Using cached HF tokenizer/model for manual FP path")
+            logger.info("Using cached HF tokenizer/model for manual path")
         else:
-            logger.info("Loading HF tokenizer/model for manual FP path")
+            logger.info("Loading HF tokenizer/model for manual tokenization path")
             hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
             torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
-            hf_model = AutoModel.from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-            )
+            hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
             hf_model.to(device)
             hf_model.eval()
             # Optional compile on supported devices
             if device in ["cuda", "mps"]:
                 try:
-                    hf_model = torch.compile(  # type: ignore
-                        hf_model, mode="reduce-overhead", dynamic=True
-                    )
-                    logger.info(
-                        f"Applied torch.compile to HF model for {model_name} "
-                        f"(device={device}, dtype={torch_dtype})"
-                    )
-                except Exception as exc:
-                    logger.warning(f"torch.compile optimization failed: {exc}")
+                    hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True)  # type: ignore
+                except Exception:
+                    pass
             _model_cache[tok_cache_key] = hf_tokenizer
             _model_cache[mdl_cache_key] = hf_model
@@ -554,6 +516,7 @@ def compute_embeddings_sentence_transformers(
         for start_index in batch_iter:
             end_index = min(start_index + batch_size, len(texts))
             batch_texts = texts[start_index:end_index]
+            tokenize_start_time = time.time()
             inputs = hf_tokenizer(
                 batch_texts,
                 padding=True,
@@ -561,17 +524,34 @@ def compute_embeddings_sentence_transformers(
                 max_length=max_length,
                 return_tensors="pt",
             )
+            tokenize_end_time = time.time()
+            logger.info(
+                f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
+            )
+            # Print shapes of all input tensors for debugging
+            for k, v in inputs.items():
+                print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
+            to_device_start_time = time.time()
             inputs = {k: v.to(device) for k, v in inputs.items()}
+            to_device_end_time = time.time()
+            logger.info(
+                f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
+            )
+            forward_start_time = time.time()
             outputs = hf_model(**inputs)
+            forward_end_time = time.time()
+            logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
             last_hidden_state = outputs.last_hidden_state  # (B, L, H)
             attention_mask = inputs.get("attention_mask")
             if attention_mask is None:
+                # Fallback: assume all tokens are valid
                 pooled = last_hidden_state.mean(dim=1)
             else:
                 mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
                 masked = last_hidden_state * mask
                 lengths = mask.sum(dim=1).clamp(min=1)
                 pooled = masked.sum(dim=1) / lengths
+            # Move to CPU float32
             batch_embeddings = pooled.detach().to("cpu").float().numpy()
             all_embeddings.append(batch_embeddings)
@@ -591,12 +571,6 @@ def compute_embeddings_sentence_transformers(
     if np.isnan(embeddings).any() or np.isinf(embeddings).any():
         raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")

-    outer_end_time = time.time()
-    logger.debug(
-        "compute_embeddings_sentence_transformers total time "
-        f"(function entry -> return): {outer_end_time - outer_start_time:.6f}s"
-    )
     return embeddings

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "leann"
-version = "0.3.4"
+version = "0.3.5"
 description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
 readme = "README.md"
 requires-python = ">=3.9"