Make DiskANN and HNSW work on main example (#2)
* fix: diskann zmq port and passages
* feat: auto discovery of packages and fix passage gen for diskann
This commit is contained in:
@@ -7,6 +7,8 @@ import pickle
|
||||
import argparse
|
||||
import threading
|
||||
import time
|
||||
import json
|
||||
from typing import Dict, Any, Optional, Union
|
||||
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import os
|
||||
@@ -17,49 +19,49 @@ import numpy as np
|
||||
RED = "\033[91m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
# 简化的文档存储 - 替代 LazyPassages
|
||||
class SimpleDocumentStore:
|
||||
"""简化的文档存储,支持任意ID"""
|
||||
def __init__(self, documents: dict = None):
|
||||
self.documents = documents or {}
|
||||
# 默认演示文档
|
||||
self.default_docs = {
|
||||
0: "Python is a high-level, interpreted language known for simplicity.",
|
||||
1: "Machine learning builds systems that learn from data.",
|
||||
2: "Data structures like arrays, lists, and graphs organize data.",
|
||||
}
|
||||
# --- New Passage Loader from HNSW backend ---
|
||||
class SimplePassageLoader:
|
||||
"""
|
||||
Simple passage loader that replaces config.py dependencies
|
||||
"""
|
||||
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
|
||||
self.passages_data = passages_data or {}
|
||||
|
||||
def __getitem__(self, doc_id):
|
||||
doc_id = int(doc_id)
|
||||
|
||||
# 优先使用指定的文档
|
||||
if doc_id in self.documents:
|
||||
return {"text": self.documents[doc_id]}
|
||||
|
||||
# 其次使用默认演示文档
|
||||
if doc_id in self.default_docs:
|
||||
return {"text": self.default_docs[doc_id]}
|
||||
|
||||
# 对于任意其他ID,返回通用文档
|
||||
fallback_docs = [
|
||||
"This is a general document about technology and programming concepts.",
|
||||
"This document discusses machine learning and artificial intelligence topics.",
|
||||
"This content covers data structures, algorithms, and computer science fundamentals.",
|
||||
"This is a document about software engineering and development practices.",
|
||||
"This content focuses on databases, data management, and information systems."
|
||||
]
|
||||
|
||||
# 根据ID选择一个fallback文档
|
||||
fallback_text = fallback_docs[doc_id % len(fallback_docs)]
|
||||
return {"text": f"[ID:{doc_id}] {fallback_text}"}
|
||||
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
||||
"""Get passage by ID"""
|
||||
str_id = str(passage_id)
|
||||
if str_id in self.passages_data:
|
||||
return {"text": self.passages_data[str_id]}
|
||||
else:
|
||||
# Return empty text for missing passages
|
||||
return {"text": ""}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.documents) + len(self.default_docs)
|
||||
def __len__(self) -> int:
|
||||
return len(self.passages_data)
|
||||
|
||||
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
    """
    Load passages from a JSON file.

    Expected format: {"passage_id": "passage_text", ...}

    This loader is best-effort: every failure is reported on stdout and an
    empty SimplePassageLoader is returned instead of raising.

    Args:
        passages_file: Path to a UTF-8 encoded JSON file whose top-level
            value is an object mapping passage IDs to passage text.

    Returns:
        A SimplePassageLoader wrapping the parsed mapping, or an empty
        SimplePassageLoader when the file is missing, cannot be read, is not
        valid JSON, or does not contain a JSON object at the top level.
    """
    if not os.path.exists(passages_file):
        print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
        return SimplePassageLoader()

    try:
        with open(passages_file, 'r', encoding='utf-8') as f:
            passages_data = json.load(f)
    except (OSError, ValueError, RecursionError) as e:
        # OSError: the file vanished or is unreadable.
        # ValueError: covers json.JSONDecodeError and UnicodeDecodeError.
        # RecursionError: pathologically nested JSON.
        print(f"Error loading passages from {passages_file}: {e}")
        return SimplePassageLoader()

    # A JSON array or scalar would be accepted by SimplePassageLoader but then
    # break on string-key lookups later, so reject it here, at load time.
    if not isinstance(passages_data, dict):
        print(
            f"Error loading passages from {passages_file}: "
            f"expected a JSON object mapping passage IDs to text, "
            f"got {type(passages_data).__name__}"
        )
        return SimplePassageLoader()

    print(f"Loaded {len(passages_data)} passages from {passages_file}")
    return SimplePassageLoader(passages_data)
|
||||
|
||||
def create_embedding_server_thread(
|
||||
zmq_port=5555,
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
max_batch_size=128,
|
||||
passages_file: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
在当前线程中创建并运行 embedding server
|
||||
@@ -109,15 +111,14 @@ def create_embedding_server_thread(
|
||||
except Exception as e:
|
||||
print(f"WARNING: Model optimization failed: {e}")
|
||||
|
||||
# 默认演示文档
|
||||
demo_documents = {
|
||||
0: "Python is a high-level, interpreted language known for simplicity.",
|
||||
1: "Machine learning builds systems that learn from data.",
|
||||
2: "Data structures like arrays, lists, and graphs organize data.",
|
||||
}
|
||||
# Load passages from file if provided
|
||||
if passages_file and os.path.exists(passages_file):
|
||||
passages = load_passages_from_file(passages_file)
|
||||
else:
|
||||
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
|
||||
passages = SimplePassageLoader()
|
||||
|
||||
passages = SimpleDocumentStore(demo_documents)
|
||||
print(f"INFO: Loaded {len(passages)} demo documents")
|
||||
print(f"INFO: Loaded {len(passages)} passages.")
|
||||
|
||||
class DeviceTimer:
|
||||
"""设备计时器"""
|
||||
@@ -264,7 +265,13 @@ def create_embedding_server_thread(
|
||||
for nid in node_ids:
|
||||
txtinfo = passages[nid]
|
||||
txt = txtinfo["text"]
|
||||
texts.append(txt)
|
||||
if txt:
|
||||
texts.append(txt)
|
||||
else:
|
||||
# 如果文本为空,我们仍然需要一个占位符来进行批处理,
|
||||
# 但将其ID记录为缺失
|
||||
texts.append("")
|
||||
missing_ids.append(nid)
|
||||
lookup_timer.print_elapsed()
|
||||
|
||||
if missing_ids:
|
||||
@@ -360,18 +367,20 @@ def create_embedding_server(
|
||||
max_batch_size=128,
|
||||
lazy_load_passages=False,
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
passages_file: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
原有的 create_embedding_server 函数保持不变
|
||||
这个是阻塞版本,用于直接运行
|
||||
"""
|
||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size)
|
||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Embedding service")
|
||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
|
||||
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
|
||||
parser.add_argument("--load-passages", action="store_true", default=True)
|
||||
parser.add_argument("--load-embeddings", action="store_true", default=False)
|
||||
parser.add_argument("--use-fp16", action="store_true", default=False)
|
||||
@@ -394,4 +403,5 @@ if __name__ == "__main__":
|
||||
max_batch_size=args.max_batch_size,
|
||||
lazy_load_passages=args.lazy_load_passages,
|
||||
model_name=args.model_name,
|
||||
passages_file=args.passages_file,
|
||||
)
|
||||
Reference in New Issue
Block a user