perf: make embedder loading faster by 6x, and embed queries through the server
@@ -16,8 +16,17 @@ import zmq
 import numpy as np
 import msgpack
 from pathlib import Path
+import logging
 
 RED = "\033[91m"
+
+# Set up logging based on environment variable
+LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
+logging.basicConfig(
+    level=getattr(logging, LOG_LEVEL, logging.INFO),
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
 RESET = "\033[0m"
 
 # --- New Passage Loader from HNSW backend ---
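The new logging setup above gates verbosity behind an environment variable. A minimal standalone sketch of the same pattern, runnable outside this codebase:

```python
import logging
import os

# The env var picks the level; the getattr(...) default means an
# unrecognized value (e.g. LEANN_LOG_LEVEL=VERBOSE) silently falls
# back to INFO instead of raising.
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.INFO),
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

logger.debug("shown only when LEANN_LOG_LEVEL=DEBUG")
logger.info("shown at the default INFO level")
```

Running with `LEANN_LOG_LEVEL=DEBUG` then turns on debug output without a code change.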
@@ -169,7 +178,7 @@ def create_embedding_server_thread(
     Create and run the embedding server in the current thread.
     This function is designed to be called in a separate thread.
     """
-    print(f"INFO: Initializing embedding server thread on port {zmq_port}")
+    logger.info(f"Initializing embedding server thread on port {zmq_port}")
 
     try:
         # Check whether the port is already in use
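The translated comment above mentions checking whether the ZMQ port is already in use, but the check itself sits outside this hunk. One common way to probe a TCP port, as a hypothetical sketch (`is_port_in_use` is not a name from this codebase):

```python
import socket

def is_port_in_use(port: int, host: str = "127.0.0.1") -> bool:
    # connect_ex returns 0 when something is already listening on the
    # port, and a nonzero errno otherwise, without raising.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex((host, port)) == 0
```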
@@ -189,7 +198,7 @@ def create_embedding_server_thread(
         if embedding_mode == "mlx":
             from leann.api import compute_embeddings_mlx
             import torch
-            print("INFO: Using MLX for embeddings")
+            logger.info("Using MLX for embeddings")
             # Set device to CPU for compatibility with DeviceTimer class
             device = torch.device("cpu")
             cuda_available = False
@@ -197,7 +206,7 @@ def create_embedding_server_thread(
         elif embedding_mode == "openai":
             from leann.api import compute_embeddings_openai
             import torch
-            print("INFO: Using OpenAI API for embeddings")
+            logger.info("Using OpenAI API for embeddings")
             # Set device to CPU for compatibility with DeviceTimer class
             device = torch.device("cpu")
             cuda_available = False
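Both backend branches above pin the device to CPU even though MLX and the OpenAI API never run tensors through PyTorch; per the comments, this is only so the DeviceTimer helper has a valid device to work with. A condensed sketch of the dispatch shape (the leann.api imports are as shown above; the return convention and function name are assumptions):

```python
import torch

def select_embedding_backend(embedding_mode: str):
    """Hypothetical sketch of the two branches above; the real code
    falls through to a PyTorch/Transformers path for other modes."""
    if embedding_mode == "mlx":
        from leann.api import compute_embeddings_mlx as compute
    elif embedding_mode == "openai":
        from leann.api import compute_embeddings_openai as compute
    else:
        raise NotImplementedError("PyTorch path omitted in this sketch")
    # Neither backend runs tensors through torch; CPU is just a valid
    # dummy device for the DeviceTimer bookkeeping.
    return compute, torch.device("cpu")
```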
@@ -213,16 +222,16 @@ def create_embedding_server_thread(
 
         if cuda_available:
             device = torch.device("cuda")
-            print("INFO: Using CUDA device")
+            logger.info("Using CUDA device")
         elif mps_available:
             device = torch.device("mps")
-            print("INFO: Using MPS device (Apple Silicon)")
+            logger.info("Using MPS device (Apple Silicon)")
         else:
             device = torch.device("cpu")
-            print("INFO: Using CPU device")
+            logger.info("Using CPU device")
 
         # Load the model
-        print(f"INFO: Loading model {model_name}")
+        logger.info(f"Loading model {model_name}")
         model = AutoModel.from_pretrained(model_name).to(device).eval()
 
         # Optimize the model
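The diff computes cuda_available and mps_available earlier, outside this hunk; a self-contained sketch of the same CUDA → MPS → CPU priority using the standard torch probes:

```python
import torch

def pick_device() -> torch.device:
    # Same priority order as the hunk above: CUDA first, then
    # Apple-Silicon MPS, then plain CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

print(pick_device())
```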
@@ -230,7 +239,7 @@ def create_embedding_server_thread(
         try:
             model = model.half()
             model = torch.compile(model)
-            print(f"INFO: Using FP16 precision with model: {model_name}")
+            logger.info(f"Using FP16 precision with model: {model_name}")
         except Exception as e:
             print(f"WARNING: Model optimization failed: {e}")
         else:
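The optimize-or-fall-back shape above is what keeps the server usable on setups where half precision or torch.compile is unavailable. A standalone sketch of the same shape (the model name is a placeholder, not from this diff):

```python
import torch
from transformers import AutoModel

model_name = "sentence-transformers/all-MiniLM-L6-v2"  # placeholder
model = AutoModel.from_pretrained(model_name).eval()

try:
    model = model.half()          # FP16 weights: half the memory traffic
    model = torch.compile(model)  # kernel fusion on supported setups
    print(f"INFO: Using FP16 precision with model: {model_name}")
except Exception as e:
    # torch.compile defers most work to the first forward pass, so this
    # mainly catches immediate setup errors (e.g. an old torch build,
    # where torch.compile does not exist at all).
    print(f"WARNING: Model optimization failed: {e}")
```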
@@ -256,7 +265,7 @@ def create_embedding_server_thread(
             print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
             passages = SimplePassageLoader()
 
-        print(f"INFO: Loaded {len(passages)} passages.")
+        logger.info(f"Loaded {len(passages)} passages.")
 
         def client_warmup(zmq_port):
             """Perform client-side warmup for DiskANN server"""
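The commit title says queries are now embedded through the server. The wire details sit outside the shown hunks; given the zmq and msgpack imports at the top of the file, a hypothetical client could look like this (the request schema, port, and function name are assumptions, not confirmed by the diff):

```python
import msgpack
import zmq

def embed_via_server(texts: list[str], zmq_port: int = 5557) -> list:
    # REQ/REP over TCP with msgpack payloads, matching the imports in
    # this file; the actual message schema is not shown in the diff.
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.REQ)
    try:
        sock.connect(f"tcp://127.0.0.1:{zmq_port}")
        sock.send(msgpack.packb(texts))
        return msgpack.unpackb(sock.recv())
    finally:
        sock.close()
```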
@@ -365,7 +374,7 @@ def create_embedding_server_thread(
         def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
             """Process a batch of texts"""
             batch_size = len(texts_batch)
-            print(f"INFO: Processing batch of size {batch_size}")
+            logger.info(f"Processing batch of size {batch_size}")
 
             tokenize_timer = DeviceTimer("tokenization (batch)", device)
             to_device_timer = DeviceTimer("transfer to device (batch)", device)
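DeviceTimer is defined elsewhere in the codebase and its API is not visible here. A hypothetical stand-in showing why a device-aware timer matters for batch measurements like the two above: CUDA kernels launch asynchronously, so the device must be synchronized before reading the clock or the timer under-counts.

```python
import time
import torch

class DeviceTimerSketch:
    """Hypothetical stand-in for DeviceTimer; the real class's API is
    not shown in this diff."""

    def __init__(self, name: str, device: torch.device):
        self.name, self.device = name, device

    def __enter__(self):
        if self.device.type == "cuda":
            torch.cuda.synchronize()  # drain pending async kernels
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        if self.device.type == "cuda":
            torch.cuda.synchronize()
        print(f"{self.name}: {time.perf_counter() - self.start:.4f}s")
```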