fix: diskann zmq port and passages

Author: Andy Lee
Date: 2025-07-06 04:14:15 +00:00
Parent: a38bc0a3fc
Commit: 5659174635

6 changed files with 71 additions and 9638 deletions

.gitignore (vendored)

@@ -71,3 +71,6 @@ test_indices*/
 test_*.py
 !tests/**
 packages/leann-backend-diskann/third_party/DiskANN/_deps/
+*.meta.json
+*.passages.json

DiskANN backend searcher (filename not shown)

@@ -3,7 +3,7 @@ import os
 import json
 import struct
 from pathlib import Path
-from typing import Dict
+from typing import Dict, Any
 import contextlib
 import threading
 import time
@@ -36,7 +36,7 @@ def chdir(path):
     finally:
         os.chdir(original_dir)

-def _write_vectors_to_bin(data: np.ndarray, file_path: str):
+def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
     num_vectors, dim = data.shape
     with open(file_path, 'wb') as f:
         f.write(struct.pack('I', num_vectors))
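
Note: the .bin file written above follows DiskANN's standard vector layout: a uint32 vector count, a uint32 dimension, then row-major float32 data. Only the first pack call is visible in this hunk, so the dimension write is assumed. A minimal reader sketch, not part of this commit:

    import struct
    import numpy as np

    def read_vectors_from_bin(file_path):
        # Inverse of _write_vectors_to_bin: [num_vectors: u32][dim: u32][float32 rows]
        with open(file_path, "rb") as f:
            num_vectors, dim = struct.unpack("II", f.read(8))
            data = np.frombuffer(f.read(num_vectors * dim * 4), dtype=np.float32)
        return data.reshape(num_vectors, dim)
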
@@ -146,11 +146,12 @@ class DiskannSearcher(LeannBackendSearcherInterface):
         num_threads = kwargs.get("num_threads", 8)
         num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
+        zmq_port = kwargs.get("zmq_port", 5555)  # Get zmq_port from kwargs
         try:
             full_index_prefix = str(index_dir / index_prefix)
             self._index = diskannpy.StaticDiskFloatIndex(
-                metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
+                metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, zmq_port, "", ""
             )
             self.num_threads = num_threads
             self.embedding_server_manager = EmbeddingServerManager(
@@ -161,7 +162,7 @@ class DiskannSearcher(LeannBackendSearcherInterface):
             print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
             raise

-    def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
+    def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
         complexity = kwargs.get("complexity", 256)
         beam_width = kwargs.get("beam_width", 4)
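
Note: the point of the zmq_port change above is that the searcher and the embedding server must agree on one port; previously no port was passed into StaticDiskFloatIndex at all. The commit does not show the wire protocol, so the following is only a toy pyzmq REQ/REP pair illustrating the port coupling the change makes configurable:

    import zmq

    ctx = zmq.Context()

    server = ctx.socket(zmq.REP)            # stands in for the embedding server
    server.bind("tcp://127.0.0.1:5555")

    client = ctx.socket(zmq.REQ)            # stands in for the DiskANN searcher
    client.connect("tcp://127.0.0.1:5555")  # must be the same port

    client.send(b"node ids to embed")
    print(server.recv())                    # b'node ids to embed'
    server.send(b"embeddings")
    print(client.recv())                    # b'embeddings'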

Embedding server (filename not shown)

@@ -7,6 +7,8 @@ import pickle
 import argparse
 import threading
 import time
+import json
+from typing import Dict, Any, Optional, Union
 from transformers import AutoTokenizer, AutoModel
 import os
@@ -17,49 +19,49 @@ import numpy as np
 RED = "\033[91m"
 RESET = "\033[0m"

-# Simplified document store - replaces LazyPassages
-class SimpleDocumentStore:
-    """Simplified document store, supports arbitrary IDs"""
-    def __init__(self, documents: dict = None):
-        self.documents = documents or {}
-        # Default demo documents
-        self.default_docs = {
-            0: "Python is a high-level, interpreted language known for simplicity.",
-            1: "Machine learning builds systems that learn from data.",
-            2: "Data structures like arrays, lists, and graphs organize data.",
-        }
-
-    def __getitem__(self, doc_id):
-        doc_id = int(doc_id)
-        # Prefer the explicitly supplied documents
-        if doc_id in self.documents:
-            return {"text": self.documents[doc_id]}
-        # Then fall back to the default demo documents
-        if doc_id in self.default_docs:
-            return {"text": self.default_docs[doc_id]}
-        # For any other ID, return a generic document
-        fallback_docs = [
-            "This is a general document about technology and programming concepts.",
-            "This document discusses machine learning and artificial intelligence topics.",
-            "This content covers data structures, algorithms, and computer science fundamentals.",
-            "This is a document about software engineering and development practices.",
-            "This content focuses on databases, data management, and information systems."
-        ]
-        # Pick a fallback document based on the ID
-        fallback_text = fallback_docs[doc_id % len(fallback_docs)]
-        return {"text": f"[ID:{doc_id}] {fallback_text}"}
-
-    def __len__(self):
-        return len(self.documents) + len(self.default_docs)
+# --- New Passage Loader from HNSW backend ---
+class SimplePassageLoader:
+    """
+    Simple passage loader that replaces config.py dependencies
+    """
+    def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
+        self.passages_data = passages_data or {}
+
+    def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
+        """Get passage by ID"""
+        str_id = str(passage_id)
+        if str_id in self.passages_data:
+            return {"text": self.passages_data[str_id]}
+        else:
+            # Return empty text for missing passages
+            return {"text": ""}
+
+    def __len__(self) -> int:
+        return len(self.passages_data)
+
+def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
+    """
+    Load passages from a JSON file
+    Expected format: {"passage_id": "passage_text", ...}
+    """
+    if not os.path.exists(passages_file):
+        print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
+        return SimplePassageLoader()
+    try:
+        with open(passages_file, 'r', encoding='utf-8') as f:
+            passages_data = json.load(f)
+        print(f"Loaded {len(passages_data)} passages from {passages_file}")
+        return SimplePassageLoader(passages_data)
+    except Exception as e:
+        print(f"Error loading passages from {passages_file}: {e}")
+        return SimplePassageLoader()

 def create_embedding_server_thread(
     zmq_port=5555,
     model_name="sentence-transformers/all-mpnet-base-v2",
     max_batch_size=128,
+    passages_file: Optional[str] = None,
 ):
     """
     Create and run the embedding server in the current thread
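
Note: load_passages_from_file expects a flat JSON object mapping string passage IDs to passage text. A minimal round-trip sketch using the loader added above (demo.passages.json is a hypothetical name, matching the *.passages.json pattern added to .gitignore):

    import json

    passages = {
        "0": "Python is a high-level, interpreted language known for simplicity.",
        "1": "Machine learning builds systems that learn from data.",
    }
    with open("demo.passages.json", "w", encoding="utf-8") as f:
        json.dump(passages, f, ensure_ascii=False)

    loader = load_passages_from_file("demo.passages.json")
    print(loader["0"]["text"])  # passage text
    print(loader[999]["text"])  # "" - missing IDs return empty text
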
@@ -109,15 +111,14 @@
     except Exception as e:
         print(f"WARNING: Model optimization failed: {e}")

-    # Default demo documents
-    demo_documents = {
-        0: "Python is a high-level, interpreted language known for simplicity.",
-        1: "Machine learning builds systems that learn from data.",
-        2: "Data structures like arrays, lists, and graphs organize data.",
-    }
-    passages = SimpleDocumentStore(demo_documents)
-    print(f"INFO: Loaded {len(passages)} demo documents")
+    # Load passages from file if provided
+    if passages_file and os.path.exists(passages_file):
+        passages = load_passages_from_file(passages_file)
+    else:
+        print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
+        passages = SimplePassageLoader()
+    print(f"INFO: Loaded {len(passages)} passages.")

     class DeviceTimer:
         """Device timer"""
@@ -264,7 +265,13 @@
             for nid in node_ids:
                 txtinfo = passages[nid]
                 txt = txtinfo["text"]
-                texts.append(txt)
+                if txt:
+                    texts.append(txt)
+                else:
+                    # If the text is empty we still need a placeholder for
+                    # batching, but record the ID as missing
+                    texts.append("")
+                    missing_ids.append(nid)
             lookup_timer.print_elapsed()

             if missing_ids:
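
Note: the invariant the branch above preserves is positional: texts[i] must stay aligned with node_ids[i] when the batch is embedded, so a missing passage gets an empty placeholder rather than being dropped. In isolation:

    node_ids = [3, 7, 42]
    passages = SimplePassageLoader({"3": "alpha", "7": "beta"})  # 42 is missing
    texts, missing_ids = [], []
    for nid in node_ids:
        txt = passages[nid]["text"]
        if txt:
            texts.append(txt)
        else:
            texts.append("")        # placeholder keeps the batch aligned
            missing_ids.append(nid)
    assert len(texts) == len(node_ids)
    print(missing_ids)  # [42]
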
@@ -360,18 +367,20 @@
     max_batch_size=128,
     lazy_load_passages=False,
     model_name="sentence-transformers/all-mpnet-base-v2",
+    passages_file: Optional[str] = None,
 ):
     """
     The original create_embedding_server function is unchanged;
     this is the blocking version, for running directly
     """
-    create_embedding_server_thread(zmq_port, model_name, max_batch_size)
+    create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Embedding service")
     parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
     parser.add_argument("--domain", type=str, default="demo", help="Domain name")
+    parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
     parser.add_argument("--load-passages", action="store_true", default=True)
     parser.add_argument("--load-embeddings", action="store_true", default=False)
     parser.add_argument("--use-fp16", action="store_true", default=False)
@@ -394,4 +403,5 @@ if __name__ == "__main__":
         max_batch_size=args.max_batch_size,
         lazy_load_passages=args.lazy_load_passages,
         model_name=args.model_name,
+        passages_file=args.passages_file,
     )
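
Note: with the new argument wired through, the blocking entry point can be driven from the CLI via --passages-file or called directly. A hedged sketch of the direct call; the zmq_port parameter is inferred from the positional call in the body above, since the hunk cuts off the start of the signature:

    create_embedding_server(
        zmq_port=5555,
        max_batch_size=128,
        model_name="sentence-transformers/all-mpnet-base-v2",
        passages_file="demo.passages.json",  # hypothetical file from the earlier sketch
    )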

HNSW backend searcher (filename not shown)

@@ -303,7 +303,8 @@ class HNSWSearcher(LeannBackendSearcherInterface):
         hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
         hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
         hnsw_config.external_storage_path = kwargs.get("external_storage_path")
-        hnsw_config.zmq_port = kwargs.get("zmq_port", 5557)
+        self.zmq_port = kwargs.get("zmq_port", 5557)

         if self.is_pruned and not hnsw_config.is_recompute:
             raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
@@ -361,13 +362,15 @@
             faiss.normalize_L2(query)

         try:
-            self._index.hnsw.efSearch = ef
+            params = faiss.SearchParametersHNSW()
+            params.efSearch = ef
+            params.zmq_port = kwargs.get("zmq_port", self.zmq_port)

             batch_size = query.shape[0]
             distances = np.empty((batch_size, top_k), dtype=np.float32)
             labels = np.empty((batch_size, top_k), dtype=np.int64)
-            self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
+            self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)

             return {"labels": labels, "distances": distances}
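
Note: switching from mutating self._index.hnsw.efSearch to a per-call faiss.SearchParametersHNSW makes the search parameters local to each query instead of shared index state, which also avoids races when concurrent searches want different settings. The zmq_port field is specific to this project's patched FAISS; with stock FAISS (1.7.3+) the same per-call pattern looks like:

    import faiss
    import numpy as np

    d = 64
    xb = np.random.rand(1000, d).astype(np.float32)
    index = faiss.IndexHNSWFlat(d, 32)   # HNSW over flat float32 storage
    index.add(xb)

    params = faiss.SearchParametersHNSW()
    params.efSearch = 128                # per-call override; index.hnsw.efSearch untouched

    xq = np.random.rand(5, d).astype(np.float32)
    D, I = index.search(xq, 10, params=params)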

(filename not shown)

File diff suppressed because it is too large.