Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG

This commit is contained in:
yichuan520030910320
2025-07-15 22:29:09 -07:00
5 changed files with 105 additions and 8 deletions

View File

@@ -163,6 +163,7 @@ def create_embedding_server_thread(
max_batch_size=128,
passages_file: Optional[str] = None,
use_mlx: bool = False,
enable_warmup: bool = False,
):
"""
在当前线程中创建并运行 embedding server
@@ -238,6 +239,62 @@ def create_embedding_server_thread(
print(f"INFO: Loaded {len(passages)} passages.")
def client_warmup(zmq_port):
"""Perform client-side warmup for DiskANN server"""
time.sleep(2)
print(f"Performing client-side warmup with model {model_name}...")
# Get actual passage IDs from the loaded passages
sample_ids = []
if hasattr(passages, 'keys') and len(passages) > 0:
available_ids = list(passages.keys())
# Take up to 5 actual IDs, but at least 1
sample_ids = available_ids[:min(5, len(available_ids))]
print(f"Using actual passage IDs for warmup: {sample_ids}")
else:
print("No passages available for warmup, skipping warmup...")
return
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 30000)
socket.setsockopt(zmq.SNDTIMEO, 30000)
try:
ids_to_send = [int(x) for x in sample_ids]
except ValueError:
print("Warning: Could not convert sample IDs to integers, skipping warmup")
return
if not ids_to_send:
print("Skipping warmup send.")
return
# Use protobuf format for warmup
from . import embedding_pb2
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.node_ids.extend(ids_to_send)
request_bytes = req_proto.SerializeToString()
for i in range(3):
print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
socket.send(request_bytes)
response_bytes = socket.recv()
resp_proto = embedding_pb2.NodeEmbeddingResponse()
resp_proto.ParseFromString(response_bytes)
embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
time.sleep(0.1)
print("Client-side Protobuf ZMQ warmup complete")
socket.close()
context.term()
except Exception as e:
print(f"Error during Protobuf ZMQ warmup: {e}")
class DeviceTimer:
"""设备计时器"""
def __init__(self, name="", device=device):
@@ -343,6 +400,16 @@ def create_embedding_server_thread(
print(f"INFO: Embedding server ready to serve requests")
# Start warmup thread if enabled
if enable_warmup and len(passages) > 0:
import threading
print(f"Warmup enabled: starting warmup thread")
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
warmup_thread.daemon = True
warmup_thread.start()
else:
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
while True:
try:
parts = socket.recv_multipart()
@@ -587,12 +654,13 @@ def create_embedding_server(
model_name="sentence-transformers/all-mpnet-base-v2",
passages_file: Optional[str] = None,
use_mlx: bool = False,
enable_warmup: bool = False,
):
"""
原有的 create_embedding_server 函数保持不变
这个是阻塞版本,用于直接运行
"""
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx)
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx, enable_warmup)
if __name__ == "__main__":
@@ -610,6 +678,7 @@ if __name__ == "__main__":
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name")
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings")
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
args = parser.parse_args()
create_embedding_server(
@@ -625,4 +694,5 @@ if __name__ == "__main__":
model_name=args.model_name,
passages_file=args.passages_file,
use_mlx=args.use_mlx,
enable_warmup=not args.disable_warmup,
)

View File

@@ -151,6 +151,7 @@ def create_hnsw_embedding_server(
custom_max_length_param: Optional[int] = None,
distance_metric: str = "mips",
use_mlx: bool = False,
enable_warmup: bool = False,
):
"""
Create and start a ZMQ-based embedding server for HNSW backend.
@@ -167,6 +168,7 @@ def create_hnsw_embedding_server(
model_name: Transformer model name
custom_max_length_param: Custom max sequence length
distance_metric: The distance metric to use
enable_warmup: Whether to perform warmup requests on server start
"""
if not use_mlx:
print(f"Loading tokenizer for {model_name}...")
@@ -465,7 +467,17 @@ def create_hnsw_embedding_server(
"""Perform client-side warmup"""
time.sleep(2)
print(f"Performing client-side warmup with model {model_name}...")
sample_ids = ["0", "1", "2", "3", "4"]
# Get actual passage IDs from the loaded passages
sample_ids = []
if hasattr(passages, 'keys') and len(passages) > 0:
available_ids = list(passages.keys())
# Take up to 5 actual IDs, but at least 1
sample_ids = available_ids[:min(5, len(available_ids))]
print(f"Using actual passage IDs for warmup: {sample_ids}")
else:
print("No passages available for warmup, skipping warmup...")
return
try:
context = zmq.Context()
@@ -477,7 +489,8 @@ def create_hnsw_embedding_server(
try:
ids_to_send = [int(x) for x in sample_ids]
except ValueError:
ids_to_send = []
print("Warning: Could not convert sample IDs to integers, skipping warmup")
return
if not ids_to_send:
print("Skipping warmup send.")
@@ -915,10 +928,13 @@ def create_hnsw_embedding_server(
pass
# Start warmup and server threads
if len(passages) > 0:
if enable_warmup and len(passages) > 0:
print(f"Warmup enabled: starting warmup thread")
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
warmup_thread.daemon = True
warmup_thread.start()
else:
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
@@ -976,6 +992,12 @@ if __name__ == "__main__":
default=False,
help="Use MLX for model inference",
)
parser.add_argument(
"--disable-warmup",
action="store_true",
default=False,
help="Disable warmup requests on server start",
)
args = parser.parse_args()
@@ -992,4 +1014,5 @@ if __name__ == "__main__":
custom_max_length_param=args.custom_max_length,
distance_metric=args.distance_metric,
use_mlx=args.use_mlx,
enable_warmup=not args.disable_warmup,
)

View File

@@ -237,7 +237,7 @@ class LeannBuilder:
class LeannSearcher:
def __init__(self, index_path: str, **backend_kwargs):
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
meta_path_str = f"{index_path}.meta.json"
if not Path(meta_path_str).exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
@@ -251,6 +251,7 @@ class LeannSearcher:
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
@@ -306,9 +307,9 @@ from .chat import get_llm
class LeannChat:
def __init__(
self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, **kwargs
self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, enable_warmup: bool = False, **kwargs
):
self.searcher = LeannSearcher(index_path, **kwargs)
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
self.llm = get_llm(llm_config)
def ask(self, question: str, top_k=5, **kwargs):

View File

@@ -184,7 +184,7 @@ class EmbeddingServerManager:
Args:
port (int): The ZMQ port for the server.
model_name (str): The name of the embedding model to use.
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric).
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup).
Returns:
bool: True if the server is started successfully or already running, False otherwise.
@@ -312,6 +312,8 @@ class EmbeddingServerManager:
# command.extend(["--distance-metric", kwargs["distance_metric"]])
if "use_mlx" in kwargs and kwargs["use_mlx"]:
command.extend(["--use-mlx"])
if "enable_warmup" in kwargs and not kwargs["enable_warmup"]:
command.extend(["--disable-warmup"])
project_root = Path(__file__).parent.parent.parent.parent.parent
print(f"INFO: Running command from project root: {project_root}")

View File

@@ -79,6 +79,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
passages_file=passages_source_file,
distance_metric=kwargs.get("distance_metric"),
use_mlx=kwargs.get("use_mlx", False),
enable_warmup=kwargs.get("enable_warmup", False),
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}")