Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG

yichuan520030910320
2025-07-17 22:29:39 -07:00
5 changed files with 348 additions and 110 deletions

View File

@@ -16,8 +16,17 @@ import zmq
import numpy as np
import msgpack
from pathlib import Path
import logging
RED = "\033[91m"
# Set up logging based on environment variable
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
logging.basicConfig(
level=getattr(logging, LOG_LEVEL, logging.INFO),
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
RESET = "\033[0m"
# --- New Passage Loader from HNSW backend ---
@@ -169,7 +178,7 @@ def create_embedding_server_thread(
Create and run the embedding server in the current thread.
This function is designed to be called from a separate thread.
"""
print(f"INFO: Initializing embedding server thread on port {zmq_port}")
logger.info(f"Initializing embedding server thread on port {zmq_port}")
try:
# Check whether the port is already in use
@@ -189,7 +198,7 @@ def create_embedding_server_thread(
if embedding_mode == "mlx":
from leann.api import compute_embeddings_mlx
import torch
print("INFO: Using MLX for embeddings")
logger.info("Using MLX for embeddings")
# Set device to CPU for compatibility with DeviceTimer class
device = torch.device("cpu")
cuda_available = False
@@ -197,7 +206,7 @@ def create_embedding_server_thread(
elif embedding_mode == "openai":
from leann.api import compute_embeddings_openai
import torch
print("INFO: Using OpenAI API for embeddings")
logger.info("Using OpenAI API for embeddings")
# Set device to CPU for compatibility with DeviceTimer class
device = torch.device("cpu")
cuda_available = False
@@ -213,16 +222,16 @@ def create_embedding_server_thread(
if cuda_available:
device = torch.device("cuda")
print("INFO: Using CUDA device")
logger.info("Using CUDA device")
elif mps_available:
device = torch.device("mps")
print("INFO: Using MPS device (Apple Silicon)")
logger.info("Using MPS device (Apple Silicon)")
else:
device = torch.device("cpu")
print("INFO: Using CPU device")
logger.info("Using CPU device")
# Load the model
print(f"INFO: Loading model {model_name}")
logger.info(f"Loading model {model_name}")
model = AutoModel.from_pretrained(model_name).to(device).eval()
# 优化模型
@@ -230,7 +239,7 @@ def create_embedding_server_thread(
try:
model = model.half()
model = torch.compile(model)
print(f"INFO: Using FP16 precision with model: {model_name}")
logger.info(f"Using FP16 precision with model: {model_name}")
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
else:
@@ -256,7 +265,7 @@ def create_embedding_server_thread(
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
passages = SimplePassageLoader()
print(f"INFO: Loaded {len(passages)} passages.")
logger.info(f"Loaded {len(passages)} passages.")
def client_warmup(zmq_port):
"""Perform client-side warmup for DiskANN server"""
@@ -365,7 +374,7 @@ def create_embedding_server_thread(
def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
"""处理文本批次"""
batch_size = len(texts_batch)
print(f"INFO: Processing batch of size {batch_size}")
logger.info(f"Processing batch of size {batch_size}")
tokenize_timer = DeviceTimer("tokenization (batch)", device)
to_device_timer = DeviceTimer("transfer to device (batch)", device)
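The print-to-logger migration above means verbosity is now controlled by the LEANN_LOG_LEVEL environment variable, which the new logging.basicConfig call reads once at module import time. A minimal sketch of how a caller might quiet the INFO chatter (the variable name and default come from this diff; the commented import path is only illustrative):

    import os

    # Must be set before the server module is imported, because the diff reads
    # LEANN_LOG_LEVEL once when configuring logging at import time.
    os.environ["LEANN_LOG_LEVEL"] = "WARNING"  # or "DEBUG" to see per-batch logs

    # import leann_backend_diskann.embedding_server  # illustrative module path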

View File

@@ -141,6 +141,14 @@ class HNSWSearcher(BaseSearcher):
raise RuntimeError("Index is pruned but recompute is disabled.")
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
# Load label mapping
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found at {label_map_file}")
with open(label_map_file, "rb") as f:
self.label_map = pickle.load(f)
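For context, a minimal sketch of how this label map is typically consumed, inferred from the DEBUG messages elsewhere in this commit that treat it as an int-label to string-passage-id mapping; the helper name below is hypothetical and not part of this diff:

    # Hypothetical helper: translate the integer labels returned by the FAISS
    # HNSW index back into the string passage ids stored in the metadata.
    def _labels_to_passage_ids(label_map, labels):
        return [label_map[int(l)] for l in labels if int(l) in label_map]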
def search(
self,

View File

@@ -18,10 +18,19 @@ import json
from pathlib import Path
from typing import Dict, Any, Optional, Union
import sys
import logging
RED = "\033[91m"
RESET = "\033[0m"
# Set up logging based on environment variable
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
logging.basicConfig(
level=getattr(logging, LOG_LEVEL, logging.INFO),
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def is_similarity_metric():
"""
@@ -36,6 +45,23 @@ import torch
from torch import Tensor
import torch.nn.functional as F
# Timing utilities
@contextmanager
def timer(name: str, sync_cuda: bool = True):
"""Context manager for timing operations with optional CUDA sync"""
start_time = time.time()
if sync_cuda and torch.cuda.is_available():
torch.cuda.synchronize()
try:
yield
finally:
if sync_cuda and torch.cuda.is_available():
torch.cuda.synchronize()
elif sync_cuda and torch.backends.mps.is_available():
torch.mps.synchronize()
elapsed = time.time() - start_time
logger.info(f"⏱️ {name}: {elapsed:.4f}s")
def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
@@ -120,13 +146,13 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
if passage_data and passage_data.get("text"):
return {"text": passage_data["text"]}
else:
print(f"DEBUG: Empty text for ID {int_id} -> {string_id}")
logger.debug(f"Empty text for ID {int_id} -> {string_id}")
return {"text": ""}
else:
print(f"DEBUG: ID {int_id} not found in label_map")
logger.debug(f"ID {int_id} not found in label_map")
return {"text": ""}
except Exception as e:
print(f"DEBUG: Exception getting passage {passage_id}: {e}")
logger.debug(f"Exception getting passage {passage_id}: {e}")
return {"text": ""}
def __len__(self) -> int:
@@ -184,8 +210,21 @@ def create_hnsw_embedding_server(
tokenizer = None # MLX handles tokenization separately
else: # sentence-transformers
print(f"Loading tokenizer for {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
print(f"Tokenizer loaded successfully!")
# Optimized tokenizer loading: try local first, then fallback
try:
tokenizer = AutoTokenizer.from_pretrained(
model_name,
use_fast=True, # Use fast tokenizer (better runtime perf)
local_files_only=True # Avoid network delays
)
print(f"Tokenizer loaded successfully! (local + fast)")
except Exception as e:
print(f"Local tokenizer failed ({e}), trying network download...")
tokenizer = AutoTokenizer.from_pretrained(
model_name,
use_fast=True # Use fast tokenizer
)
print(f"Tokenizer loaded successfully! (network)")
# Device setup
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -217,9 +256,47 @@ def create_hnsw_embedding_server(
print("OpenAI API mode - no local model loading required")
model = None
else:
# Use standard transformers for sentence-transformers models
model = AutoModel.from_pretrained(model_name).to(device).eval()
print(f"Model {model_name} loaded successfully!")
# Use optimized transformers loading for sentence-transformers models
print(f"Loading model with optimizations...")
try:
# Ultra-fast loading: preload config + fast_init
from transformers import AutoConfig
config = AutoConfig.from_pretrained(model_name, local_files_only=True)
model = AutoModel.from_pretrained(
model_name,
config=config,
torch_dtype=torch.float16, # Half precision for speed
low_cpu_mem_usage=True, # Reduce memory peaks
local_files_only=True, # Avoid network delays
_fast_init=True # Skip weight init checks
).to(device).eval()
print(f"Model {model_name} loaded successfully! (ultra-fast)")
except Exception as e:
print(f"Ultra-fast loading failed ({e}), trying optimized...")
try:
# Fallback: regular optimized loading
model = AutoModel.from_pretrained(
model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
local_files_only=True
).to(device).eval()
print(f"Model {model_name} loaded successfully! (optimized)")
except Exception as e2:
print(f"Optimized loading failed ({e2}), trying network...")
try:
# Fallback: optimized network loading
model = AutoModel.from_pretrained(
model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True
).to(device).eval()
print(f"Model {model_name} loaded successfully! (network + optimized)")
except Exception as e3:
print(f"All optimized methods failed ({e3}), using standard...")
# Final fallback: standard loading
model = AutoModel.from_pretrained(model_name).to(device).eval()
print(f"Model {model_name} loaded successfully! (standard)")
# Check port availability
import socket
@@ -370,8 +447,9 @@ def create_hnsw_embedding_server(
if embedding_mode == "mlx":
return _process_batch_mlx(texts_batch, ids_batch, missing_ids)
elif embedding_mode == "openai":
from leann.api import compute_embeddings_openai
return compute_embeddings_openai(texts_batch, model_name)
with timer("OpenAI API call", sync_cuda=False):
from leann.api import compute_embeddings_openai
return compute_embeddings_openai(texts_batch, model_name)
_is_e5_model = "e5" in model_name.lower()
_is_bge_model = "bge" in model_name.lower()
@@ -417,44 +495,46 @@ def create_hnsw_embedding_server(
enc = {k: v.to(device) for k, v in encoded_batch.items()}
with torch.no_grad():
with embed_timer.timing():
out = model(enc["input_ids"], enc["attention_mask"])
with timer("Model forward pass"):
with embed_timer.timing():
out = model(enc["input_ids"], enc["attention_mask"])
with pool_timer.timing():
if _is_bge_model:
pooled_embeddings = out.last_hidden_state[:, 0]
elif not hasattr(out, "last_hidden_state"):
if isinstance(out, torch.Tensor) and len(out.shape) == 2:
pooled_embeddings = out
with timer("Pooling"):
with pool_timer.timing():
if _is_bge_model:
pooled_embeddings = out.last_hidden_state[:, 0]
elif not hasattr(out, "last_hidden_state"):
if isinstance(out, torch.Tensor) and len(out.shape) == 2:
pooled_embeddings = out
else:
print(
f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
)
hidden_dim = getattr(
model.config, "hidden_size", 384 if _is_e5_model else 768
)
pooled_embeddings = torch.zeros(
(batch_size, hidden_dim),
device=device,
dtype=enc["input_ids"].dtype
if hasattr(enc["input_ids"], "dtype")
else torch.float32,
)
elif _is_e5_model:
pooled_embeddings = e5_average_pool(
out.last_hidden_state, enc["attention_mask"]
)
else:
print(
f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
hidden_states = out.last_hidden_state
mask_expanded = (
enc["attention_mask"]
.unsqueeze(-1)
.expand(hidden_states.size())
.float()
)
hidden_dim = getattr(
model.config, "hidden_size", 384 if _is_e5_model else 768
)
pooled_embeddings = torch.zeros(
(batch_size, hidden_dim),
device=device,
dtype=enc["input_ids"].dtype
if hasattr(enc["input_ids"], "dtype")
else torch.float32,
)
elif _is_e5_model:
pooled_embeddings = e5_average_pool(
out.last_hidden_state, enc["attention_mask"]
)
else:
hidden_states = out.last_hidden_state
mask_expanded = (
enc["attention_mask"]
.unsqueeze(-1)
.expand(hidden_states.size())
.float()
)
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
pooled_embeddings = sum_embeddings / sum_mask
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
pooled_embeddings = sum_embeddings / sum_mask
final_embeddings = pooled_embeddings
if _is_e5_model or _is_bge_model:
@@ -536,7 +616,7 @@ def create_hnsw_embedding_server(
def zmq_server_thread():
"""ZMQ server thread"""
nonlocal passages, model, tokenizer, model_name
nonlocal passages, model, tokenizer, model_name, embedding_mode
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{zmq_port}")
@@ -556,13 +636,13 @@ def create_hnsw_embedding_server(
try:
request_payload = msgpack.unpackb(message_bytes)
if isinstance(request_payload, list):
print(f"DEBUG: request_payload length: {len(request_payload)}")
logger.debug(f"request_payload length: {len(request_payload)}")
for i, item in enumerate(request_payload):
print(
f"DEBUG: request_payload[{i}]: {type(item)} - {item if len(str(item)) < 100 else str(item)[:100] + '...'}"
)
# Handle control messages for meta path and model management
# Handle control messages for meta path and model management FIRST
if isinstance(request_payload, list) and len(request_payload) >= 1:
if request_payload[0] == "__QUERY_META_PATH__":
# Return the current meta path being used by the server
@@ -617,19 +697,61 @@ def create_hnsw_embedding_server(
)
# Clean up old model to free memory
print("INFO: Releasing old model from memory...")
logger.info("Releasing old model from memory...")
old_model = model
old_tokenizer = tokenizer
# Load new tokenizer first
# Load new tokenizer first (optimized)
print(f"Loading new tokenizer for {new_model_name}...")
tokenizer = AutoTokenizer.from_pretrained(
new_model_name, use_fast=True
)
try:
tokenizer = AutoTokenizer.from_pretrained(
new_model_name,
use_fast=True,
local_files_only=True
)
print(f"New tokenizer loaded! (local + fast)")
except Exception:
tokenizer = AutoTokenizer.from_pretrained(
new_model_name,
use_fast=True
)
print(f"New tokenizer loaded! (network + fast)")
# Load new model
# Load new model (optimized)
print(f"Loading new model {new_model_name}...")
model = AutoModel.from_pretrained(new_model_name)
try:
# Ultra-fast model switching
from transformers import AutoConfig
config = AutoConfig.from_pretrained(new_model_name, local_files_only=True)
model = AutoModel.from_pretrained(
new_model_name,
config=config,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
local_files_only=True,
_fast_init=True
)
print(f"New model loaded! (ultra-fast)")
except Exception:
try:
model = AutoModel.from_pretrained(
new_model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
local_files_only=True
)
print(f"New model loaded! (optimized)")
except Exception:
try:
model = AutoModel.from_pretrained(
new_model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True
)
print(f"New model loaded! (network + optimized)")
except Exception:
model = AutoModel.from_pretrained(new_model_name)
print(f"New model loaded! (standard)")
model.to(device)
model.eval()
@@ -640,19 +762,27 @@ def create_hnsw_embedding_server(
# Clear GPU cache if available
if device.type == "cuda":
torch.cuda.empty_cache()
print("INFO: Cleared CUDA cache")
logger.info("Cleared CUDA cache")
elif device.type == "mps":
torch.mps.empty_cache()
print("INFO: Cleared MPS cache")
logger.info("Cleared MPS cache")
# Update model name
model_name = new_model_name
# Re-detect embedding mode based on new model name
if model_name.startswith("text-embedding-"):
embedding_mode = "openai"
logger.info(f"Auto-detected embedding mode: openai for {model_name}")
else:
embedding_mode = "sentence-transformers"
logger.info(f"Auto-detected embedding mode: sentence-transformers for {model_name}")
# Force garbage collection
import gc
gc.collect()
print("INFO: Memory cleanup completed")
logger.info("Memory cleanup completed")
response = ["SUCCESS"]
print(
@@ -664,6 +794,32 @@ def create_hnsw_embedding_server(
socket.send(msgpack.packb(response))
continue
# Handle direct text embedding request (for OpenAI and sentence-transformers)
if isinstance(request_payload, list) and len(request_payload) > 0:
# Check if this is a direct text request (list of strings) and NOT a control message
if (all(isinstance(item, str) for item in request_payload) and
not request_payload[0].startswith("__")):
logger.info(f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode")
try:
if embedding_mode == "openai":
from leann.api import compute_embeddings_openai
embeddings = compute_embeddings_openai(request_payload, model_name)
else:
# sentence-transformers mode - compute directly
with timer(f"Direct text embedding ({len(request_payload)} texts)"):
embeddings = process_batch(request_payload, [], [])
response = embeddings.tolist()
socket.send(msgpack.packb(response))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
continue
except Exception as e:
logger.error(f"ERROR: Failed to compute {embedding_mode} embeddings: {e}")
socket.send(msgpack.packb([]))
continue
# Handle distance calculation requests
if (
isinstance(request_payload, list)
@@ -674,7 +830,7 @@ def create_hnsw_embedding_server(
node_ids = request_payload[0]
query_vector = np.array(request_payload[1], dtype=np.float32)
print("DEBUG: Distance calculation request received")
logger.debug("Distance calculation request received")
print(f" Node IDs: {node_ids}")
print(f" Query vector dim: {len(query_vector)}")
print(f" Passages loaded: {len(passages)}")
@@ -684,7 +840,7 @@ def create_hnsw_embedding_server(
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
print(f"DEBUG: Looking up passage ID {nid}")
logger.debug(f"Looking up passage ID {nid}")
try:
txtinfo = passages[nid]
if txtinfo is None:
@@ -804,29 +960,11 @@ def create_hnsw_embedding_server(
elif device.type == "mps":
torch.mps.synchronize()
e2e_end = time.time()
print(
f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds"
logger.info(
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
)
continue
# Handle direct text embedding request (for OpenAI mode)
if embedding_mode == "openai" and isinstance(request_payload, list) and len(request_payload) > 0:
# Check if this is a direct text request (list of strings)
if all(isinstance(item, str) for item in request_payload):
print(f"Processing direct text embedding request for {len(request_payload)} texts")
try:
from leann.api import compute_embeddings_openai
embeddings = compute_embeddings_openai(request_payload, model_name)
response = embeddings.tolist()
socket.send(msgpack.packb(response))
e2e_end = time.time()
print(f"Text embedding E2E time: {e2e_end - e2e_start:.6f} seconds")
continue
except Exception as e:
print(f"ERROR: Failed to compute OpenAI embeddings: {e}")
socket.send(msgpack.packb([]))
continue
# Standard embedding request (passage ID lookup)
if (
@@ -945,10 +1083,10 @@ def create_hnsw_embedding_server(
elif device.type == "mps":
torch.mps.synchronize()
e2e_end = time.time()
print(f"ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
print("ZMQ socket timeout, continuing to listen")
logger.debug("ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
print(f"Error in ZMQ server loop: {e}")

View File

@@ -20,7 +20,8 @@ from .chat import get_llm
def compute_embeddings(
chunks: List[str],
model_name: str,
mode: str = "sentence-transformers"
mode: str = "sentence-transformers",
use_server: bool = True
) -> np.ndarray:
"""
Computes embeddings using different backends.
@@ -32,6 +33,7 @@ def compute_embeddings(
- "sentence-transformers": Use sentence-transformers library (default)
- "mlx": Use MLX backend for Apple Silicon
- "openai": Use OpenAI embedding API
use_server: Whether to use embedding server (True for search, False for build)
Returns:
numpy array of embeddings
@@ -45,13 +47,79 @@ def compute_embeddings(
elif mode == "openai":
return compute_embeddings_openai(chunks, model_name)
elif mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(chunks, model_name)
return compute_embeddings_sentence_transformers(chunks, model_name, use_server=use_server)
else:
raise ValueError(f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai")
def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings using sentence-transformers library."""
def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str, use_server: bool = True) -> np.ndarray:
"""Computes embeddings using sentence-transformers.
Args:
chunks: List of text chunks to embed
model_name: Name of the sentence transformer model
use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
"""
if not use_server:
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)...")
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
)
# Use embedding server for sentence-transformers too
# This avoids loading the model twice (once in API, once in server)
try:
# Import ZMQ client functionality and server manager
import zmq
import msgpack
import numpy as np
from .embedding_server_manager import EmbeddingServerManager
# Ensure embedding server is running
port = 5557
server_manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw.hnsw_embedding_server")
server_started = server_manager.start_server(
port=port,
model_name=model_name,
embedding_mode="sentence-transformers",
enable_warmup=False,
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}")
# Connect to embedding server
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{port}")
# Send chunks to server for embedding computation
request = chunks
socket.send(msgpack.packb(request))
# Receive embeddings from server
response = socket.recv()
embeddings_list = msgpack.unpackb(response)
# Convert back to numpy array
embeddings = np.array(embeddings_list, dtype=np.float32)
socket.close()
context.term()
return embeddings
except Exception as e:
# Fallback to direct sentence-transformers if server connection fails
print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}")
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray:
"""Direct sentence-transformers computation (fallback)."""
try:
from sentence_transformers import SentenceTransformer
except ImportError as e:
@@ -64,7 +132,7 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str)
model = model.half()
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'..."
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
)
# use accelerator GPU (CUDA) or Mac GPU (MPS) if available
@@ -240,7 +308,7 @@ class LeannBuilder:
raise ValueError("No chunks added.")
if self.dimensions is None:
self.dimensions = len(
compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode)[0]
compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode, use_server=False)[0]
)
path = Path(index_path)
index_dir = path.parent
@@ -267,7 +335,7 @@ class LeannBuilder:
pickle.dump(offset_map, f)
texts_to_embed = [c["text"] for c in self.chunks]
embeddings = compute_embeddings(
texts_to_embed, self.embedding_model, self.embedding_mode
texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
)
string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
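Taken together, these api.py changes route search-time embedding through the shared ZMQ server (default use_server=True, port 5557) while index builds compute directly in-process. A hedged usage sketch of the new parameter; the model name is illustrative, not taken from this diff:

    from leann.api import compute_embeddings

    # Build path: compute in-process, no embedding server is started or reused.
    build_vecs = compute_embeddings(
        ["passage one", "passage two"],
        "sentence-transformers/all-MiniLM-L6-v2",  # illustrative model name
        mode="sentence-transformers",
        use_server=False,
    )

    # Search path: the default use_server=True goes through the embedding server,
    # falling back to direct computation if the server cannot be reached.
    query_vec = compute_embeddings(
        ["example query"],
        "sentence-transformers/all-MiniLM-L6-v2",
        mode="sentence-transformers",
    )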

View File

@@ -200,7 +200,27 @@ class EmbeddingServerManager:
# Check model compatibility
model_matches = _check_server_model(self.server_port, model_name)
if not model_matches:
if model_matches:
print(
f"✅ Existing server already using correct model: {model_name}"
)
# Still check meta path if provided
passages_file = kwargs.get("passages_file")
if passages_file and str(passages_file).endswith(
".meta.json"
):
meta_matches = _check_server_meta_path(
self.server_port, str(passages_file)
)
if not meta_matches:
print("⚠️ Updating meta path to: {passages_file}")
_update_server_meta_path(
self.server_port, str(passages_file)
)
return True
else:
print(
f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
)
@@ -230,11 +250,6 @@ class EmbeddingServerManager:
)
return True
else:
print(
f"✅ Existing server already using correct model: {model_name}"
)
return True
else:
# Server process exists but port not responding - restart
print("⚠️ Server process exists but not responding. Restarting...")
@@ -254,7 +269,11 @@ class EmbeddingServerManager:
# Check model compatibility first
model_matches = _check_server_model(port, model_name)
if not model_matches:
if model_matches:
print(
f"✅ Existing server on port {port} is using correct model: {model_name}"
)
else:
print(
f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
)
@@ -263,10 +282,6 @@ class EmbeddingServerManager:
f"❌ Failed to update server model to {model_name}. Consider using a different port."
)
print(f"✅ Successfully updated server model to: {model_name}")
else:
print(
f"✅ Existing server on port {port} is using correct model: {model_name}"
)
# Check meta path compatibility if provided
if passages_file and str(passages_file).endswith(".meta.json"):
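For reference, the caller side of this manager, as exercised by the api.py change earlier in this commit, looks roughly like the sketch below; the import path and model name are assumptions, while the port, backend module name, and keyword arguments mirror that call:

    from leann.embedding_server_manager import EmbeddingServerManager  # assumed import path

    manager = EmbeddingServerManager(
        backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
    )
    started = manager.start_server(
        port=5557,
        model_name="sentence-transformers/all-MiniLM-L6-v2",  # illustrative
        embedding_mode="sentence-transformers",
        enable_warmup=False,
    )
    if not started:
        raise RuntimeError("embedding server failed to start")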