perf: make embedder loading faster by 6x, and embed queries through the server

Andy Lee
2025-07-17 20:08:06 -07:00
parent 99d439577d
commit 1c5fec5565
4 changed files with 323 additions and 105 deletions
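The second half of the title refers to routing query embedding through the already-running ZMQ embedding server instead of loading a second copy of the model in the client. A minimal sketch of such a client, assuming a msgpack-over-ZMQ request that carries a list of texts and returns a list of vectors (the actual wire schema and port are not shown in this diff):

import zmq
import msgpack
import numpy as np

def embed_query(text, zmq_port=5555):  # port value is a placeholder
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.REQ)
    sock.connect(f"tcp://127.0.0.1:{zmq_port}")
    sock.send(msgpack.packb([text]))      # one-element batch of texts
    reply = msgpack.unpackb(sock.recv())  # assumed: list of embedding vectors
    sock.close()
    return np.asarray(reply[0], dtype=np.float32)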


@@ -16,8 +16,17 @@ import zmq
 import numpy as np
 import msgpack
 from pathlib import Path
+import logging
 RED = "\033[91m"
+# Set up logging based on environment variable
+LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
+logging.basicConfig(
+    level=getattr(logging, LOG_LEVEL, logging.INFO),
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
 RESET = "\033[0m"
 # --- New Passage Loader from HNSW backend ---
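One detail of the logging setup above worth noting: getattr(logging, LOG_LEVEL, logging.INFO) falls back to INFO for unrecognized values of LEANN_LOG_LEVEL instead of raising. A quick illustration:

import logging
import os

LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
level = getattr(logging, LOG_LEVEL, logging.INFO)
# DEBUG -> 10, INFO -> 20; any unknown value also resolves to 20 (INFO)
print(level)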
@@ -169,7 +178,7 @@ def create_embedding_server_thread(
     Create and run the embedding server in the current thread.
     This function is designed to be called from a separate thread.
     """
-    print(f"INFO: Initializing embedding server thread on port {zmq_port}")
+    logger.info(f"Initializing embedding server thread on port {zmq_port}")
     try:
         # Check whether the port is already in use
@@ -189,7 +198,7 @@ def create_embedding_server_thread(
     if embedding_mode == "mlx":
         from leann.api import compute_embeddings_mlx
         import torch
-        print("INFO: Using MLX for embeddings")
+        logger.info("Using MLX for embeddings")
         # Set device to CPU for compatibility with DeviceTimer class
         device = torch.device("cpu")
         cuda_available = False
@@ -197,7 +206,7 @@ def create_embedding_server_thread(
     elif embedding_mode == "openai":
         from leann.api import compute_embeddings_openai
         import torch
-        print("INFO: Using OpenAI API for embeddings")
+        logger.info("Using OpenAI API for embeddings")
         # Set device to CPU for compatibility with DeviceTimer class
         device = torch.device("cpu")
         cuda_available = False
@@ -213,16 +222,16 @@ def create_embedding_server_thread(
     if cuda_available:
         device = torch.device("cuda")
-        print("INFO: Using CUDA device")
+        logger.info("Using CUDA device")
     elif mps_available:
         device = torch.device("mps")
-        print("INFO: Using MPS device (Apple Silicon)")
+        logger.info("Using MPS device (Apple Silicon)")
     else:
         device = torch.device("cpu")
-        print("INFO: Using CPU device")
+        logger.info("Using CPU device")
     # Load the model
-    print(f"INFO: Loading model {model_name}")
+    logger.info(f"Loading model {model_name}")
     model = AutoModel.from_pretrained(model_name).to(device).eval()
     # Optimize the model
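The cuda_available / mps_available flags are set outside the lines shown here; the standard PyTorch probes for this fallback chain look like the following sketch (not the exact code from this file):

import torch

cuda_available = torch.cuda.is_available()
mps_available = torch.backends.mps.is_available()  # Apple Silicon GPU

# Same preference order as the diff: CUDA, then MPS, then CPU
device = torch.device("cuda" if cuda_available else "mps" if mps_available else "cpu")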
@@ -230,7 +239,7 @@ def create_embedding_server_thread(
     try:
         model = model.half()
         model = torch.compile(model)
-        print(f"INFO: Using FP16 precision with model: {model_name}")
+        logger.info(f"Using FP16 precision with model: {model_name}")
     except Exception as e:
         print(f"WARNING: Model optimization failed: {e}")
     else:
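For reference, the optimization pattern used above, FP16 weights plus torch.compile with a fallback to the unoptimized model, as a standalone sketch (the model name here is a placeholder, not taken from this commit):

import torch
from transformers import AutoModel

model_name = "sentence-transformers/all-MiniLM-L6-v2"  # placeholder
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(model_name).to(device).eval()

try:
    model = model.half()          # FP16 halves weight memory traffic
    model = torch.compile(model)  # compiles lazily on the first forward pass
except Exception as e:
    print(f"WARNING: Model optimization failed: {e}")  # keep unoptimized model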
@@ -256,7 +265,7 @@ def create_embedding_server_thread(
         print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
         passages = SimplePassageLoader()
-    print(f"INFO: Loaded {len(passages)} passages.")
+    logger.info(f"Loaded {len(passages)} passages.")
     def client_warmup(zmq_port):
         """Perform client-side warmup for DiskANN server"""
@@ -365,7 +374,7 @@ def create_embedding_server_thread(
     def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
         """Process a batch of texts"""
         batch_size = len(texts_batch)
-        print(f"INFO: Processing batch of size {batch_size}")
+        logger.info(f"Processing batch of size {batch_size}")
         tokenize_timer = DeviceTimer("tokenization (batch)", device)
         to_device_timer = DeviceTimer("transfer to device (batch)", device)
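The rest of process_batch_pytorch is cut off here. A generic sketch of the tokenize / move-to-device / forward shape that the timers above are measuring, with mask-aware mean pooling as an assumed pooling strategy (the actual pooling used in this file is not visible in the diff):

import torch

def embed_batch(texts, tokenizer, model, device):
    # Tokenize the whole batch at once, padded to the longest text
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        out = model(**enc)
    # Assumed pooling: mean over token embeddings, ignoring padding
    mask = enc["attention_mask"].unsqueeze(-1).float()
    summed = (out.last_hidden_state * mask).sum(dim=1)
    return (summed / mask.sum(dim=1)).cpu().numpy()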