Compare commits
13 Commits
embed-laun... → fix/securi...
| Author | SHA1 | Date |
|---|---|---|
|  | cf67a848f4 |  |
|  | 31ee48e3c8 |  |
|  | 3b23330bde |  |
|  | 1ad9f75e96 |  |
|  | 15d15f8881 |  |
|  | d70e8fe40c |  |
|  | e8c4ccde53 |  |
|  | d27970538a |  |
|  | dab299043d |  |
|  | 620da9dc27 |  |
|  | 27d5a49f94 |  |
|  | 043e32d959 |  |
|  | 3c4785bb63 |  |
2  .github/workflows/link-check.yml  vendored

@@ -14,6 +14,6 @@ jobs:
       - uses: actions/checkout@v4
       - uses: lycheeverse/lychee-action@v2
         with:
-          args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
+          args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
4  .lycheeignore  Normal file

@@ -0,0 +1,4 @@
+# Exclude star-history API from link checking
+# This service is intermittently unavailable (503 errors)
+# but the link still works when the service is up
+.*api\.star-history\.com.*
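Note: the ignore file and the new `--exclude` flag carry the same regular expression, so the pattern can be sanity-checked locally before pushing. A minimal sketch, assuming nothing beyond the pattern itself (the example URLs below are illustrative, not taken from the repository):

```python
import re

# Same pattern as in .lycheeignore / the --exclude flag above.
pattern = re.compile(r".*api\.star-history\.com.*")

# Hypothetical links that might appear in README.md.
links = [
    "https://api.star-history.com/svg?repos=example/repo",  # would be skipped
    "https://github.com/example/repo",                        # would still be checked
]

for link in links:
    status = "skipped" if pattern.match(link) else "checked"
    print(f"{status}: {link}")
```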
@@ -7,6 +7,7 @@ for indexing in LEANN. It supports various Slack MCP server implementations and
 flexible message processing options.
 """

+import ast
 import asyncio
 import json
 import logging
@@ -146,16 +147,16 @@ class SlackMCPReader:
             match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
             if match:
                 try:
-                    error_dict = eval(match.group(1))
-                except (ValueError, SyntaxError, NameError):
+                    error_dict = ast.literal_eval(match.group(1))
+                except (ValueError, SyntaxError):
                     pass
             else:
                 # Try alternative format
                 match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
                 if match:
                     try:
-                        error_dict = eval(match.group(1))
-                    except (ValueError, SyntaxError, NameError):
+                        error_dict = ast.literal_eval(match.group(1))
+                    except (ValueError, SyntaxError):
                         pass

             if self._is_cache_sync_error(error_dict):
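The eval to ast.literal_eval switch is the security fix in this branch: literal_eval only accepts Python literals (dicts, lists, strings, numbers), so an error payload scraped out of an exception string can no longer smuggle in executable code. A minimal sketch of the difference, with made-up payload strings:

```python
import ast

# A benign error payload like the ones extracted from the MCP exception text.
benign = "{'ok': False, 'error': 'not_in_channel'}"
print(ast.literal_eval(benign))  # {'ok': False, 'error': 'not_in_channel'}

# A malicious payload: eval() would execute this expression,
# ast.literal_eval() rejects it with ValueError instead.
malicious = "__import__('os').system('echo pwned')"
try:
    ast.literal_eval(malicious)
except (ValueError, SyntaxError) as exc:
    print(f"rejected: {exc}")
```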
@@ -1,98 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script to reproduce issue #159: Slow search performance
-Configuration:
-- GPU: A10
-- embedding_model: BAAI/bge-large-zh-v1.5
-- data size: 180M text (~90K chunks)
-- backend: hnsw
-"""
-
-import os
-import time
-from pathlib import Path
-
-from leann.api import LeannBuilder, LeannSearcher
-
-os.environ["LEANN_LOG_LEVEL"] = "DEBUG"
-
-# Configuration matching the issue
-INDEX_PATH = "./test_issue_159.leann"
-EMBEDDING_MODEL = "BAAI/bge-large-zh-v1.5"
-BACKEND_NAME = "hnsw"
-
-
-def generate_test_data(num_chunks=90000, chunk_size=2000):
-    """Generate test data similar to 180MB text (~90K chunks)"""
-    # Each chunk is approximately 2000 characters
-    # 90K chunks * 2000 chars ≈ 180MB
-    chunks = []
-    base_text = (
-        "这是一个测试文档。LEANN是一个创新的向量数据库, 通过图基选择性重计算实现97%的存储节省。"
-    )
-
-    for i in range(num_chunks):
-        chunk = f"{base_text} 文档编号: {i}. " * (chunk_size // len(base_text) + 1)
-        chunks.append(chunk[:chunk_size])
-
-    return chunks
-
-
-def test_search_performance():
-    """Test search performance with different configurations"""
-    print("=" * 80)
-    print("Testing LEANN Search Performance (Issue #159)")
-    print("=" * 80)
-
-    meta_path = Path(f"{INDEX_PATH}.meta.json")
-    if meta_path.exists():
-        print(f"\n✓ Index already exists at {INDEX_PATH}")
-        print("  Skipping build phase. Delete the index to rebuild.")
-    else:
-        print("\n📦 Building index...")
-        print(f"  Backend: {BACKEND_NAME}")
-        print(f"  Embedding Model: {EMBEDDING_MODEL}")
-        print("  Generating test data (~90K chunks, ~180MB)...")
-
-        chunks = generate_test_data(num_chunks=90000)
-        print(f"  Generated {len(chunks)} chunks")
-        print(f"  Total text size: {sum(len(c) for c in chunks) / (1024 * 1024):.2f} MB")
-
-        builder = LeannBuilder(
-            backend_name=BACKEND_NAME,
-            embedding_model=EMBEDDING_MODEL,
-        )
-
-        print("  Adding chunks to builder...")
-        start_time = time.time()
-        for i, chunk in enumerate(chunks):
-            builder.add_text(chunk)
-            if (i + 1) % 10000 == 0:
-                print(f"    Added {i + 1}/{len(chunks)} chunks...")
-
-        print("  Building index...")
-        build_start = time.time()
-        builder.build_index(INDEX_PATH)
-        build_time = time.time() - build_start
-        print(f"  ✓ Index built in {build_time:.2f} seconds")
-
-    # Test search with different complexity values
-    print("\n🔍 Testing search performance...")
-    searcher = LeannSearcher(INDEX_PATH)
-
-    test_query = "LEANN向量数据库存储优化"
-
-    # Test with minimal complexity (8)
-    print("\n  Test 4: Minimal complexity (8)")
-    print(f"  Query: '{test_query}'")
-    start_time = time.time()
-    results = searcher.search(test_query, top_k=10, complexity=8)
-    search_time = time.time() - start_time
-    print(f"  ✓ Search completed in {search_time:.2f} seconds")
-    print(f"  Results: {len(results)} items")
-
-    print("\n" + "=" * 80)
-
-
-if __name__ == "__main__":
-    test_search_performance()
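The deleted script above pinned a single low complexity value (8); the underlying knob is the candidate-list size, so a sweep makes the latency/accuracy trade-off visible. A hedged sketch of such a sweep, reusing the index path and query from the deleted script (timings will obviously vary with hardware):

```python
import time

from leann.api import LeannSearcher

searcher = LeannSearcher("./test_issue_159.leann")  # index built by the deleted script
query = "LEANN向量数据库存储优化"

# Higher complexity means a larger candidate list: more accurate, but slower.
for complexity in (8, 32, 64, 128):
    t0 = time.time()
    results = searcher.search(query, top_k=10, complexity=complexity)
    print(f"complexity={complexity:>3}: {time.time() - t0:.2f}s, {len(results)} results")
```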
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"

 [project]
 name = "leann-backend-diskann"
-version = "0.3.4"
-dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
+version = "0.3.5"
+dependencies = ["leann-core==0.3.5", "numpy", "protobuf>=3.19.0"]

 [tool.scikit-build]
 # Key: simplified CMake path
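Note the exact pin: the backend packages depend on leann-core with ==, so the 0.3.4 to 0.3.5 bump has to land in core and in every backend in the same release. A quick local sanity check of that lockstep could look like this (distribution names come from the pyproject files in this compare; the check itself is only illustrative):

```python
from importlib.metadata import version

# Illustrative lockstep check after installing the bumped packages.
core = version("leann-core")
for backend in ("leann-backend-diskann", "leann-backend-hnsw"):
    assert version(backend) == core == "0.3.5", f"{backend} out of sync with leann-core"
print("core and backends are in lockstep at", core)
```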
@@ -143,6 +143,8 @@ def create_hnsw_embedding_server(
                 pass
             return str(nid)

+    # (legacy ZMQ thread removed; using shutdown-capable server only)
+
     def zmq_server_thread_with_shutdown(shutdown_event):
         """ZMQ server thread that respects shutdown signal.

@@ -156,238 +158,225 @@ def create_hnsw_embedding_server(
         rep_socket.bind(f"tcp://*:{zmq_port}")
         logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
         rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
+        # Keep sends from blocking during shutdown; fail fast and drop on close
         rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
         rep_socket.setsockopt(zmq.LINGER, 0)

-        last_request_type = "unknown"
+        # Track last request type/length for shape-correct fallbacks
+        last_request_type = "unknown"  # 'text' | 'distance' | 'embedding' | 'unknown'
         last_request_length = 0

-        def _build_safe_fallback():
-            if last_request_type == "distance":
-                large_distance = 1e9
-                fallback_len = max(0, int(last_request_length))
-                return [[large_distance] * fallback_len]
-            if last_request_type == "embedding":
-                bsz = max(0, int(last_request_length))
-                dim = max(0, int(embedding_dim))
-                if dim > 0:
-                    return [[bsz, dim], [0.0] * (bsz * dim)]
-                return [[0, 0], []]
-            if last_request_type == "text":
-                return []
-            return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
-
-        def _handle_text_embedding(request: list[str]) -> None:
-            nonlocal last_request_type, last_request_length
-
-            e2e_start = time.time()
-            last_request_type = "text"
-            last_request_length = len(request)
-            embeddings = compute_embeddings(
-                request,
-                model_name,
-                mode=embedding_mode,
-                provider_options=PROVIDER_OPTIONS,
-            )
-            rep_socket.send(msgpack.packb(embeddings.tolist()))
-            e2e_end = time.time()
-            logger.info(f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s")
-
-        def _handle_distance_request(request: list[Any]) -> None:
-            nonlocal last_request_type, last_request_length
-
-            e2e_start = time.time()
-            node_ids = request[0]
-            if len(node_ids) == 1 and isinstance(node_ids[0], list):
-                node_ids = node_ids[0]
-            query_vector = np.array(request[1], dtype=np.float32)
-            last_request_type = "distance"
-            last_request_length = len(node_ids)
-
-            logger.debug("Distance calculation request received")
-            logger.debug(f"  Node IDs: {node_ids}")
-            logger.debug(f"  Query vector dim: {len(query_vector)}")
-
-            texts: list[str] = []
-            found_indices: list[int] = []
-            for idx, nid in enumerate(node_ids):
-                try:
-                    passage_id = _map_node_id(nid)
-                    passage_data = passages.get_passage(passage_id)
-                    txt = passage_data.get("text", "")
-                    if isinstance(txt, str) and len(txt) > 0:
-                        texts.append(txt)
-                        found_indices.append(idx)
-                    else:
-                        logger.error(f"Empty text for passage ID {passage_id}")
-                except KeyError:
-                    logger.error(f"Passage ID {nid} not found")
-                except Exception as exc:
-                    logger.error(f"Exception looking up passage ID {nid}: {exc}")
-
-            large_distance = 1e9
-            response_distances = [large_distance] * len(node_ids)
-
-            if texts:
-                try:
-                    embeddings = compute_embeddings(
-                        texts,
-                        model_name,
-                        mode=embedding_mode,
-                        provider_options=PROVIDER_OPTIONS,
-                    )
-                    logger.info(
-                        f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
-                    )
-                    if distance_metric == "l2":
-                        partial = np.sum(
-                            np.square(embeddings - query_vector.reshape(1, -1)), axis=1
-                        )
-                    else:
-                        partial = -np.dot(embeddings, query_vector)
-
-                    for pos, dval in zip(found_indices, partial.flatten().tolist()):
-                        response_distances[pos] = float(dval)
-                except Exception as exc:
-                    logger.error(f"Distance computation error, using sentinels: {exc}")
-
-            rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
-            e2e_end = time.time()
-            logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
-
-        def _handle_embedding_by_id(request: Any) -> None:
-            nonlocal last_request_type, last_request_length
-
-            if isinstance(request, list) and len(request) == 1 and isinstance(request[0], list):
-                node_ids = request[0]
-            elif isinstance(request, list):
-                node_ids = request
-            else:
-                node_ids = []
-
-            e2e_start = time.time()
-            last_request_type = "embedding"
-            last_request_length = len(node_ids)
-            logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
-
-            if embedding_dim <= 0:
-                dims = [0, 0]
-                flat_data: list[float] = []
-            else:
-                dims = [len(node_ids), embedding_dim]
-                flat_data = [0.0] * (dims[0] * dims[1])
-
-            texts: list[str] = []
-            found_indices: list[int] = []
-            for idx, nid in enumerate(node_ids):
-                try:
-                    passage_id = _map_node_id(nid)
-                    passage_data = passages.get_passage(passage_id)
-                    txt = passage_data.get("text", "")
-                    if isinstance(txt, str) and len(txt) > 0:
-                        texts.append(txt)
-                        found_indices.append(idx)
-                    else:
-                        logger.error(f"Empty text for passage ID {passage_id}")
-                except KeyError:
-                    logger.error(f"Passage with ID {nid} not found")
-                except Exception as exc:
-                    logger.error(f"Exception looking up passage ID {nid}: {exc}")
-
-            if texts:
-                try:
-                    embeddings = compute_embeddings(
-                        texts,
-                        model_name,
-                        mode=embedding_mode,
-                        provider_options=PROVIDER_OPTIONS,
-                    )
-                    logger.info(
-                        f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
-                    )
-
-                    if np.isnan(embeddings).any() or np.isinf(embeddings).any():
-                        logger.error(
-                            f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
-                        )
-                        dims = [0, embedding_dim]
-                        flat_data = []
-                    else:
-                        emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
-                        flat = emb_f32.flatten().tolist()
-                        for j, pos in enumerate(found_indices):
-                            start = pos * embedding_dim
-                            end = start + embedding_dim
-                            if end <= len(flat_data):
-                                flat_data[start:end] = flat[
-                                    j * embedding_dim : (j + 1) * embedding_dim
-                                ]
-                except Exception as exc:
-                    logger.error(f"Embedding computation error, returning zeros: {exc}")
-
-            response_payload = [dims, flat_data]
-            rep_socket.send(msgpack.packb(response_payload, use_single_float=True))
-            e2e_end = time.time()
-            logger.info(f"⏱️ Fallback Embed by Id E2E time: {e2e_end - e2e_start:.6f}s")
-
         try:
             while not shutdown_event.is_set():
                 try:
+                    e2e_start = time.time()
                     logger.debug("🔍 Waiting for ZMQ message...")
                     request_bytes = rep_socket.recv()
-                except zmq.Again:
-                    continue

-                try:
+                    # Rest of the processing logic (same as original)
                     request = msgpack.unpackb(request_bytes)
-                except Exception as exc:
-                    if shutdown_event.is_set():
-                        logger.info("Shutdown in progress, ignoring ZMQ error")
-                        break
-                    logger.error(f"Error unpacking ZMQ message: {exc}")
-                    try:
-                        safe = _build_safe_fallback()
-                        rep_socket.send(msgpack.packb(safe, use_single_float=True))
-                    except Exception:
-                        pass
-                    continue
-
-                try:
-                    # Model query
+
+                    if len(request) == 1 and request[0] == "__QUERY_MODEL__":
+                        response_bytes = msgpack.packb([model_name])
+                        rep_socket.send(response_bytes)
+                        continue
+
+                    # Handle direct text embedding request
                     if (
-                        isinstance(request, list)
-                        and len(request) == 1
-                        and request[0] == "__QUERY_MODEL__"
-                    ):
-                        rep_socket.send(msgpack.packb([model_name]))
-                    # Direct text embedding
-                    elif (
                         isinstance(request, list)
                         and request
                         and all(isinstance(item, str) for item in request)
                     ):
-                        _handle_text_embedding(request)
-                    # Distance calculation: [[ids], [query_vector]]
-                    elif (
+                        last_request_type = "text"
+                        last_request_length = len(request)
+                        embeddings = compute_embeddings(
+                            request,
+                            model_name,
+                            mode=embedding_mode,
+                            provider_options=PROVIDER_OPTIONS,
+                        )
+                        rep_socket.send(msgpack.packb(embeddings.tolist()))
+                        e2e_end = time.time()
+                        logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
+                        continue
+
+                    # Handle distance calculation request: [[ids], [query_vector]]
+                    if (
                         isinstance(request, list)
                         and len(request) == 2
                         and isinstance(request[0], list)
                         and isinstance(request[1], list)
                     ):
-                        _handle_distance_request(request)
-                    # Embedding-by-id fallback
+                        node_ids = request[0]
+                        # Handle nested [[ids]] shape defensively
+                        if len(node_ids) == 1 and isinstance(node_ids[0], list):
+                            node_ids = node_ids[0]
+                        query_vector = np.array(request[1], dtype=np.float32)
+                        last_request_type = "distance"
+                        last_request_length = len(node_ids)
+
+                        logger.debug("Distance calculation request received")
+                        logger.debug(f"  Node IDs: {node_ids}")
+                        logger.debug(f"  Query vector dim: {len(query_vector)}")
+
+                        # Gather texts for found ids
+                        texts: list[str] = []
+                        found_indices: list[int] = []
+                        for idx, nid in enumerate(node_ids):
+                            try:
+                                passage_id = _map_node_id(nid)
+                                passage_data = passages.get_passage(passage_id)
+                                txt = passage_data.get("text", "")
+                                if isinstance(txt, str) and len(txt) > 0:
+                                    texts.append(txt)
+                                    found_indices.append(idx)
+                                else:
+                                    logger.error(f"Empty text for passage ID {passage_id}")
+                            except KeyError:
+                                logger.error(f"Passage ID {nid} not found")
+                            except Exception as e:
+                                logger.error(f"Exception looking up passage ID {nid}: {e}")
+
+                        # Prepare full-length response with large sentinel values
+                        large_distance = 1e9
+                        response_distances = [large_distance] * len(node_ids)
+
+                        if texts:
+                            try:
+                                embeddings = compute_embeddings(
+                                    texts,
+                                    model_name,
+                                    mode=embedding_mode,
+                                    provider_options=PROVIDER_OPTIONS,
+                                )
+                                logger.info(
+                                    f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
+                                )
+                                if distance_metric == "l2":
+                                    partial = np.sum(
+                                        np.square(embeddings - query_vector.reshape(1, -1)), axis=1
+                                    )
+                                else:  # mips or cosine
+                                    partial = -np.dot(embeddings, query_vector)
+
+                                for pos, dval in zip(found_indices, partial.flatten().tolist()):
+                                    response_distances[pos] = float(dval)
+                            except Exception as e:
+                                logger.error(f"Distance computation error, using sentinels: {e}")
+
+                        # Send response in expected shape [[distances]]
+                        rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
+                        e2e_end = time.time()
+                        logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
+                        continue
+
+                    # Fallback: treat as embedding-by-id request
+                    if (
+                        isinstance(request, list)
+                        and len(request) == 1
+                        and isinstance(request[0], list)
+                    ):
+                        node_ids = request[0]
+                    elif isinstance(request, list):
+                        node_ids = request
+                    else:
+                        node_ids = []
+                    last_request_type = "embedding"
+                    last_request_length = len(node_ids)
+                    logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
+
+                    # Preallocate zero-filled flat data for robustness
+                    if embedding_dim <= 0:
+                        dims = [0, 0]
+                        flat_data: list[float] = []
+                    else:
+                        dims = [len(node_ids), embedding_dim]
+                        flat_data = [0.0] * (dims[0] * dims[1])
+
+                    # Collect texts for found ids
+                    texts: list[str] = []
+                    found_indices: list[int] = []
+                    for idx, nid in enumerate(node_ids):
+                        try:
+                            passage_id = _map_node_id(nid)
+                            passage_data = passages.get_passage(passage_id)
+                            txt = passage_data.get("text", "")
+                            if isinstance(txt, str) and len(txt) > 0:
+                                texts.append(txt)
+                                found_indices.append(idx)
+                            else:
+                                logger.error(f"Empty text for passage ID {passage_id}")
+                        except KeyError:
+                            logger.error(f"Passage with ID {nid} not found")
+                        except Exception as e:
+                            logger.error(f"Exception looking up passage ID {nid}: {e}")
+
+                    if texts:
+                        try:
+                            embeddings = compute_embeddings(
+                                texts,
+                                model_name,
+                                mode=embedding_mode,
+                                provider_options=PROVIDER_OPTIONS,
+                            )
+                            logger.info(
+                                f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
+                            )
+
+                            if np.isnan(embeddings).any() or np.isinf(embeddings).any():
+                                logger.error(
+                                    f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
+                                )
+                                dims = [0, embedding_dim]
+                                flat_data = []
+                            else:
+                                emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
+                                flat = emb_f32.flatten().tolist()
+                                for j, pos in enumerate(found_indices):
+                                    start = pos * embedding_dim
+                                    end = start + embedding_dim
+                                    if end <= len(flat_data):
+                                        flat_data[start:end] = flat[
+                                            j * embedding_dim : (j + 1) * embedding_dim
+                                        ]
+                        except Exception as e:
+                            logger.error(f"Embedding computation error, returning zeros: {e}")
+
+                    response_payload = [dims, flat_data]
+                    response_bytes = msgpack.packb(response_payload, use_single_float=True)
+
+                    rep_socket.send(response_bytes)
+                    e2e_end = time.time()
+                    logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
+
+                except zmq.Again:
+                    # Timeout - check shutdown_event and continue
+                    continue
+                except Exception as e:
+                    if not shutdown_event.is_set():
+                        logger.error(f"Error in ZMQ server loop: {e}")
+                        # Shape-correct fallback
+                        try:
+                            if last_request_type == "distance":
+                                large_distance = 1e9
+                                fallback_len = max(0, int(last_request_length))
+                                safe = [[large_distance] * fallback_len]
+                            elif last_request_type == "embedding":
+                                bsz = max(0, int(last_request_length))
+                                dim = max(0, int(embedding_dim))
+                                safe = (
+                                    [[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
+                                )
+                            elif last_request_type == "text":
+                                safe = []  # direct text embeddings expectation is a flat list
+                            else:
+                                safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
+                            rep_socket.send(msgpack.packb(safe, use_single_float=True))
+                        except Exception:
+                            pass
                     else:
-                        _handle_embedding_by_id(request)
-                except Exception as exc:
-                    if shutdown_event.is_set():
                         logger.info("Shutdown in progress, ignoring ZMQ error")
                         break
-                    logger.error(f"Error in ZMQ server loop: {exc}")
-                    try:
-                        safe = _build_safe_fallback()
-                        rep_socket.send(msgpack.packb(safe, use_single_float=True))
-                    except Exception:
-                        pass
         finally:
             try:
                 rep_socket.close(0)
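For orientation, the REP loop above dispatches purely on the shape of the msgpack payload: a list of strings is a direct text-embedding request, a two-element [[ids], [query_vector]] list is a distance request, anything else falls back to embedding-by-id, and ["__QUERY_MODEL__"] returns [model_name]. A minimal client-side sketch of those shapes under stated assumptions (the port and the 768 dimension are illustrative, not taken from this diff):

```python
import msgpack
import zmq

# Illustrative endpoint; the real port is whatever zmq_port the server was started with.
ctx = zmq.Context.instance()
sock = ctx.socket(zmq.REQ)
sock.connect("tcp://127.0.0.1:5557")

def ask(payload):
    # Same wire format the server uses: msgpack request in, msgpack response out.
    sock.send(msgpack.packb(payload, use_single_float=True))
    return msgpack.unpackb(sock.recv())

model = ask(["__QUERY_MODEL__"])            # -> [model_name]
vecs = ask(["hello world", "second doc"])   # -> one embedding list per input text
dists = ask([[3, 7, 42], [0.1] * 768])      # -> [[d3, d7, d42]] distances to the query vector
embs = ask([[3, 7, 42]])                    # -> [[3, dim], [flat float data]] embeddings by id
```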
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"

 [project]
 name = "leann-backend-hnsw"
-version = "0.3.4"
+version = "0.3.5"
 description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
 dependencies = [
-    "leann-core==0.3.4",
+    "leann-core==0.3.5",
     "numpy",
     "pyzmq>=23.0.0",
     "msgpack>=1.0.0",
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 301bf24f14...e2d243c40d
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "leann-core"
-version = "0.3.4"
+version = "0.3.5"
 description = "Core API and plugin system for LEANN"
 readme = "README.md"
 requires-python = ">=3.9"
@@ -864,13 +864,7 @@ class LeannBuilder:


 class LeannSearcher:
-    def __init__(
-        self,
-        index_path: str,
-        enable_warmup: bool = True,
-        recompute_embeddings: bool = True,
-        **backend_kwargs,
-    ):
+    def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
         # Fix path resolution for Colab and other environments
         if not Path(index_path).is_absolute():
             index_path = str(Path(index_path).resolve())
@@ -901,32 +895,14 @@ class LeannSearcher:
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' not found.")

-        # Global recompute flag for this searcher (explicit knob, default True)
-        self.recompute_embeddings: bool = bool(recompute_embeddings)
-
-        # Warmup flag: keep using the existing enable_warmup parameter,
-        # but default it to True so cold-start happens earlier.
-        self._warmup: bool = bool(enable_warmup)
-
         final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
-        final_kwargs["enable_warmup"] = self._warmup
+        final_kwargs["enable_warmup"] = enable_warmup
         if self.embedding_options:
             final_kwargs.setdefault("embedding_options", self.embedding_options)
         self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
             index_path, **final_kwargs
         )

-        # Optional one-shot warmup at construction time to hide cold-start latency.
-        if self._warmup:
-            try:
-                _ = self.backend_impl.compute_query_embedding(
-                    "__LEANN_WARMUP__",
-                    use_server_if_available=self.recompute_embeddings,
-                )
-            except Exception as exc:
-                logger.warning(f"Warmup embedding failed (ignored): {exc}")
-
     def search(
         self,
         query: str,
@@ -934,7 +910,7 @@ class LeannSearcher:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: Optional[bool] = None,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         expected_zmq_port: int = 5557,
         metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
@@ -951,8 +927,7 @@ class LeannSearcher:
             complexity: Search complexity/candidate list size, higher = more accurate but slower
             beam_width: Number of parallel search paths/IO requests per iteration
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
-            recompute_embeddings: (Deprecated) Per-call override for recompute mode.
-                Configure this at LeannSearcher(..., recompute_embeddings=...) instead.
+            recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
             pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
             expected_zmq_port: ZMQ port for embedding server communication
             metadata_filters: Optional filters to apply to search results based on metadata.
@@ -991,19 +966,8 @@ class LeannSearcher:

         zmq_port = None

-        # Resolve effective recompute flag for this search.
-        if recompute_embeddings is not None:
-            logger.warning(
-                "LeannSearcher.search(..., recompute_embeddings=...) is deprecated and "
-                "will be removed in a future version. Configure recompute at "
-                "LeannSearcher(..., recompute_embeddings=...) instead."
-            )
-            effective_recompute = bool(recompute_embeddings)
-        else:
-            effective_recompute = self.recompute_embeddings
-
         start_time = time.time()
-        if effective_recompute:
+        if recompute_embeddings:
             zmq_port = self.backend_impl._ensure_server_running(
                 self.meta_path_str,
                 port=expected_zmq_port,
@@ -1017,7 +981,7 @@ class LeannSearcher:

         query_embedding = self.backend_impl.compute_query_embedding(
             query,
-            use_server_if_available=effective_recompute,
+            use_server_if_available=recompute_embeddings,
             zmq_port=zmq_port,
         )
         logger.info(f"  Generated embedding shape: {query_embedding.shape}")
@@ -1029,7 +993,7 @@ class LeannSearcher:
             "complexity": complexity,
             "beam_width": beam_width,
             "prune_ratio": prune_ratio,
-            "recompute_embeddings": effective_recompute,
+            "recompute_embeddings": recompute_embeddings,
             "pruning_strategy": pruning_strategy,
             "zmq_port": zmq_port,
         }
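The net effect on the public API on this side of the compare: recompute and warmup are no longer constructor-level knobs; recompute_embeddings goes back to being a per-search argument (default True) and enable_warmup defaults to False. A hedged usage sketch of the resulting call pattern (index path and query strings are placeholders):

```python
from leann.api import LeannSearcher

# Placeholder index path; any existing .leann index works here.
searcher = LeannSearcher("./my_docs.leann", enable_warmup=False)

# Per-call choice: recompute embeddings via the embedding server (accurate, slower) ...
fresh = searcher.search("storage savings", top_k=10, recompute_embeddings=True)

# ... or fall back to stored/approximate codes for this one query.
cached = searcher.search("storage savings", top_k=10, recompute_embeddings=False)
```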
@@ -215,14 +215,9 @@ def compute_embeddings(
         Normalized embeddings array, shape: (len(texts), embedding_dim)
     """
     provider_options = provider_options or {}
-    wrapper_start_time = time.time()
-    logger.debug(
-        f"[compute_embeddings] entry: mode={mode}, model='{model_name}', text_count={len(texts)}"
-    )

     if mode == "sentence-transformers":
-        inner_start_time = time.time()
-        result = compute_embeddings_sentence_transformers(
+        return compute_embeddings_sentence_transformers(
             texts,
             model_name,
             is_build=is_build,
@@ -231,14 +226,6 @@ def compute_embeddings(
             manual_tokenize=manual_tokenize,
             max_length=max_length,
         )
-        inner_end_time = time.time()
-        wrapper_end_time = time.time()
-        logger.debug(
-            "[compute_embeddings] sentence-transformers timings: "
-            f"inner={inner_end_time - inner_start_time:.6f}s, "
-            f"wrapper_total={wrapper_end_time - wrapper_start_time:.6f}s"
-        )
-        return result
     elif mode == "openai":
         return compute_embeddings_openai(
             texts,
@@ -284,7 +271,6 @@ def compute_embeddings_sentence_transformers(
         is_build: Whether this is a build operation (shows progress bar)
         adaptive_optimization: Whether to use adaptive optimization based on batch size
     """
-    outer_start_time = time.time()
     # Handle empty input
     if not texts:
         raise ValueError("Cannot compute embeddings for empty text list")
@@ -315,14 +301,7 @@ def compute_embeddings_sentence_transformers(
     # Create cache key
     cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"

-    pre_model_init_end_time = time.time()
-    logger.debug(
-        "compute_embeddings_sentence_transformers pre-model-init time "
-        f"(device/batch selection etc.): {pre_model_init_end_time - outer_start_time:.6f}s"
-    )
-
     # Check if model is already cached
-    start_time = time.time()
     if cache_key in _model_cache:
         logger.info(f"Using cached optimized model: {model_name}")
         model = _model_cache[cache_key]
@@ -462,13 +441,10 @@ def compute_embeddings_sentence_transformers(
         _model_cache[cache_key] = model
         logger.info(f"Model cached: {cache_key}")

-    end_time = time.time()
     # Compute embeddings with optimized inference mode
     logger.info(
         f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
     )
-    logger.info(f"start sentence transformers {model} takes {end_time - start_time}")
-
     start_time = time.time()
     if not manual_tokenize:
@@ -489,46 +465,32 @@ def compute_embeddings_sentence_transformers(
             except Exception:
                 pass
     else:
-        # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel.
-        # This path is reserved for an aggressively optimized FP pipeline
-        # (no quantization), mainly for experimentation.
+        # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
        try:
            from transformers import AutoModel, AutoTokenizer  # type: ignore
        except Exception as e:
            raise ImportError(f"transformers is required for manual_tokenize=True: {e}")

+        # Cache tokenizer and model
         tok_cache_key = f"hf_tokenizer_{model_name}"
-        mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}_fp"
+        mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"

         if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
             hf_tokenizer = _model_cache[tok_cache_key]
             hf_model = _model_cache[mdl_cache_key]
-            logger.info("Using cached HF tokenizer/model for manual FP path")
+            logger.info("Using cached HF tokenizer/model for manual path")
         else:
-            logger.info("Loading HF tokenizer/model for manual FP path")
+            logger.info("Loading HF tokenizer/model for manual tokenization path")
             hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

             torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
-            hf_model = AutoModel.from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-            )
+            hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
             hf_model.to(device)

             hf_model.eval()
             # Optional compile on supported devices
             if device in ["cuda", "mps"]:
                 try:
-                    hf_model = torch.compile(  # type: ignore
-                        hf_model, mode="reduce-overhead", dynamic=True
-                    )
-                    logger.info(
-                        f"Applied torch.compile to HF model for {model_name} "
-                        f"(device={device}, dtype={torch_dtype})"
-                    )
-                except Exception as exc:
-                    logger.warning(f"torch.compile optimization failed: {exc}")
+                    hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True)  # type: ignore
+                except Exception:
+                    pass

             _model_cache[tok_cache_key] = hf_tokenizer
             _model_cache[mdl_cache_key] = hf_model

@@ -554,6 +516,7 @@ def compute_embeddings_sentence_transformers(
         for start_index in batch_iter:
             end_index = min(start_index + batch_size, len(texts))
             batch_texts = texts[start_index:end_index]
+            tokenize_start_time = time.time()
             inputs = hf_tokenizer(
                 batch_texts,
                 padding=True,
@@ -561,17 +524,34 @@ def compute_embeddings_sentence_transformers(
                 max_length=max_length,
                 return_tensors="pt",
             )
+            tokenize_end_time = time.time()
+            logger.info(
+                f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
+            )
+            # Print shapes of all input tensors for debugging
+            for k, v in inputs.items():
+                print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
+            to_device_start_time = time.time()
             inputs = {k: v.to(device) for k, v in inputs.items()}
+            to_device_end_time = time.time()
+            logger.info(
+                f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
+            )
+            forward_start_time = time.time()
             outputs = hf_model(**inputs)
+            forward_end_time = time.time()
+            logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
             last_hidden_state = outputs.last_hidden_state  # (B, L, H)
             attention_mask = inputs.get("attention_mask")
             if attention_mask is None:
+                # Fallback: assume all tokens are valid
                 pooled = last_hidden_state.mean(dim=1)
             else:
                 mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
                 masked = last_hidden_state * mask
                 lengths = mask.sum(dim=1).clamp(min=1)
                 pooled = masked.sum(dim=1) / lengths
+            # Move to CPU float32
             batch_embeddings = pooled.detach().to("cpu").float().numpy()
             all_embeddings.append(batch_embeddings)

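The manual-tokenize path pools token states into one vector per text with attention-mask-weighted mean pooling, so padding tokens do not dilute the embedding. A standalone sketch of the same math with made-up tensor sizes (batch 2, sequence 4, hidden 3):

```python
import torch

# Hypothetical shapes: batch=2, seq_len=4, hidden=3.
last_hidden_state = torch.randn(2, 4, 3)
# Second sequence has only 2 real tokens; the rest is padding.
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (B, L, 1)
masked = last_hidden_state * mask                                # zero out padding positions
lengths = mask.sum(dim=1).clamp(min=1)                           # real-token counts per text
pooled = masked.sum(dim=1) / lengths                             # (B, H) mean over real tokens

# For the fully unpadded first sequence this matches a plain mean over tokens.
assert torch.allclose(pooled[0], last_hidden_state[0].mean(dim=0), atol=1e-6)
```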
@@ -591,12 +571,6 @@ def compute_embeddings_sentence_transformers(
     if np.isnan(embeddings).any() or np.isinf(embeddings).any():
         raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")

-    outer_end_time = time.time()
-    logger.debug(
-        "compute_embeddings_sentence_transformers total time "
-        f"(function entry -> return): {outer_end_time - outer_start_time:.6f}s"
-    )
-
     return embeddings

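With the wrapper-level timing instrumentation removed from compute_embeddings, a caller who still wants a rough number can simply time the call site. A minimal sketch; the import path is an assumption for illustration, only the function name and signature come from this diff:

```python
import time

# Import path assumed; adjust to wherever compute_embeddings lives in leann-core.
from leann.embedding_compute import compute_embeddings

t0 = time.time()
embeddings = compute_embeddings(
    ["hello world"],
    "BAAI/bge-large-zh-v1.5",
    mode="sentence-transformers",
)
print(f"shape={embeddings.shape}, took {time.time() - t0:.3f}s")
```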
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "leann"
-version = "0.3.4"
+version = "0.3.5"
 description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
 readme = "README.md"
 requires-python = ">=3.9"