Make DiskANN and HNSW work on main example (#2)

* fix: diskann zmq port and passages

* feat: auto discovery of packages and fix passage gen for diskann
This commit is contained in:
Andy Lee
2025-07-06 13:18:12 +08:00
committed by GitHub
parent a38bc0a3fc
commit cf17c85607
9 changed files with 149 additions and 9653 deletions

View File

@@ -7,6 +7,8 @@ import pickle
import argparse
import threading
import time
import json
from typing import Dict, Any, Optional, Union
from transformers import AutoTokenizer, AutoModel
import os
@@ -17,49 +19,49 @@ import numpy as np
RED = "\033[91m"
RESET = "\033[0m"
# 简化的文档存储 - 替代 LazyPassages
class SimpleDocumentStore:
"""简化的文档存储支持任意ID"""
def __init__(self, documents: dict = None):
self.documents = documents or {}
# 默认演示文档
self.default_docs = {
0: "Python is a high-level, interpreted language known for simplicity.",
1: "Machine learning builds systems that learn from data.",
2: "Data structures like arrays, lists, and graphs organize data.",
}
# --- New Passage Loader from HNSW backend ---
class SimplePassageLoader:
"""
Simple passage loader that replaces config.py dependencies
"""
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
self.passages_data = passages_data or {}
def __getitem__(self, doc_id):
doc_id = int(doc_id)
# 优先使用指定的文档
if doc_id in self.documents:
return {"text": self.documents[doc_id]}
# 其次使用默认演示文档
if doc_id in self.default_docs:
return {"text": self.default_docs[doc_id]}
# 对于任意其他ID返回通用文档
fallback_docs = [
"This is a general document about technology and programming concepts.",
"This document discusses machine learning and artificial intelligence topics.",
"This content covers data structures, algorithms, and computer science fundamentals.",
"This is a document about software engineering and development practices.",
"This content focuses on databases, data management, and information systems."
]
# 根据ID选择一个fallback文档
fallback_text = fallback_docs[doc_id % len(fallback_docs)]
return {"text": f"[ID:{doc_id}] {fallback_text}"}
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID"""
str_id = str(passage_id)
if str_id in self.passages_data:
return {"text": self.passages_data[str_id]}
else:
# Return empty text for missing passages
return {"text": ""}
def __len__(self):
return len(self.documents) + len(self.default_docs)
def __len__(self) -> int:
return len(self.passages_data)
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
"""
Load passages from a JSON file
Expected format: {"passage_id": "passage_text", ...}
"""
if not os.path.exists(passages_file):
print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
return SimplePassageLoader()
try:
with open(passages_file, 'r', encoding='utf-8') as f:
passages_data = json.load(f)
print(f"Loaded {len(passages_data)} passages from {passages_file}")
return SimplePassageLoader(passages_data)
except Exception as e:
print(f"Error loading passages from {passages_file}: {e}")
return SimplePassageLoader()
def create_embedding_server_thread(
zmq_port=5555,
model_name="sentence-transformers/all-mpnet-base-v2",
max_batch_size=128,
passages_file: Optional[str] = None,
):
"""
在当前线程中创建并运行 embedding server
@@ -109,15 +111,14 @@ def create_embedding_server_thread(
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
# 默认演示文档
demo_documents = {
0: "Python is a high-level, interpreted language known for simplicity.",
1: "Machine learning builds systems that learn from data.",
2: "Data structures like arrays, lists, and graphs organize data.",
}
# Load passages from file if provided
if passages_file and os.path.exists(passages_file):
passages = load_passages_from_file(passages_file)
else:
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
passages = SimplePassageLoader()
passages = SimpleDocumentStore(demo_documents)
print(f"INFO: Loaded {len(passages)} demo documents")
print(f"INFO: Loaded {len(passages)} passages.")
class DeviceTimer:
"""设备计时器"""
@@ -264,7 +265,13 @@ def create_embedding_server_thread(
for nid in node_ids:
txtinfo = passages[nid]
txt = txtinfo["text"]
texts.append(txt)
if txt:
texts.append(txt)
else:
# 如果文本为空,我们仍然需要一个占位符来进行批处理,
# 但将其ID记录为缺失
texts.append("")
missing_ids.append(nid)
lookup_timer.print_elapsed()
if missing_ids:
@@ -360,18 +367,20 @@ def create_embedding_server(
max_batch_size=128,
lazy_load_passages=False,
model_name="sentence-transformers/all-mpnet-base-v2",
passages_file: Optional[str] = None,
):
"""
原有的 create_embedding_server 函数保持不变
这个是阻塞版本,用于直接运行
"""
create_embedding_server_thread(zmq_port, model_name, max_batch_size)
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
parser.add_argument("--load-passages", action="store_true", default=True)
parser.add_argument("--load-embeddings", action="store_true", default=False)
parser.add_argument("--use-fp16", action="store_true", default=False)
@@ -394,4 +403,5 @@ if __name__ == "__main__":
max_batch_size=args.max_batch_size,
lazy_load_passages=args.lazy_load_passages,
model_name=args.model_name,
passages_file=args.passages_file,
)