fix: same embedding logic

Andy Lee
2025-07-21 20:12:40 -07:00
parent f47f76d6d7
commit 54155e8b10
5 changed files with 558 additions and 1333 deletions

View File

File diff suppressed because it is too large.

View File

@@ -9,9 +9,6 @@ import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field
import uuid
import torch
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
from .chat import get_llm
@@ -22,7 +19,7 @@ def compute_embeddings(
model_name: str,
mode: str = "sentence-transformers",
use_server: bool = True,
use_mlx: bool = False,  # Backward compatibility: if True, override mode to 'mlx'
port: int = 5557,
) -> np.ndarray:
"""
Computes embeddings using different backends.
@@ -39,254 +36,60 @@ def compute_embeddings(
Returns:
numpy array of embeddings
"""
# Override mode for backward compatibility
if use_mlx:
mode = "mlx"
# Auto-detect mode based on model name if not explicitly set
if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
mode = "openai"
if mode == "mlx":
return compute_embeddings_mlx(chunks, model_name, batch_size=16)
elif mode == "openai":
return compute_embeddings_openai(chunks, model_name)
elif mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(
chunks, model_name, use_server=use_server
)
if use_server:
# Use embedding server (for search/query)
return compute_embeddings_via_server(chunks, model_name, port=port)
else:
raise ValueError(
f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
# Use direct computation (for build_index)
from .embedding_compute import (
compute_embeddings as compute_embeddings_direct,
)
return compute_embeddings_direct(
chunks,
model_name,
mode=mode,
)
def compute_embeddings_sentence_transformers(
chunks: List[str], model_name: str, use_server: bool = True
def compute_embeddings_via_server(
chunks: List[str], model_name: str, port: int
) -> np.ndarray:
"""Computes embeddings using sentence-transformers.
Args:
chunks: List of text chunks to embed
model_name: Name of the sentence transformer model
use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
"""
if not use_server:
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
)
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
)
import zmq
import msgpack
import numpy as np
# Use embedding server for sentence-transformers too
# This avoids loading the model twice (once in API, once in server)
try:
# Import ZMQ client functionality and server manager
import zmq
import msgpack
import numpy as np
from .embedding_server_manager import EmbeddingServerManager
# Connect to embedding server
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{port}")
# Ensure embedding server is running
port = 5557
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
# Send chunks to server for embedding computation
request = chunks
socket.send(msgpack.packb(request))
server_started, actual_port = server_manager.start_server(
port=port,
model_name=model_name,
embedding_mode="sentence-transformers",
enable_warmup=False,
)
# Receive embeddings from server
response = socket.recv()
embeddings_list = msgpack.unpackb(response)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
# Use the actual port for connection
port = actual_port
# Convert back to numpy array
embeddings = np.array(embeddings_list, dtype=np.float32)
# Connect to embedding server
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{port}")
# Send chunks to server for embedding computation
request = chunks
socket.send(msgpack.packb(request))
# Receive embeddings from server
response = socket.recv()
embeddings_list = msgpack.unpackb(response)
# Convert back to numpy array
embeddings = np.array(embeddings_list, dtype=np.float32)
socket.close()
context.term()
return embeddings
except Exception as e:
# Fallback to direct sentence-transformers if server connection fails
print(
f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
)
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
def _compute_embeddings_sentence_transformers_direct(
chunks: List[str], model_name: str
) -> np.ndarray:
"""Direct sentence-transformers computation (fallback)."""
try:
from sentence_transformers import SentenceTransformer
except ImportError as e:
raise RuntimeError(
"sentence-transformers not available. Install with: uv pip install sentence-transformers"
) from e
# Load model using sentence-transformers
model = SentenceTransformer(model_name)
model = model.half()
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
)
# use an accelerator GPU (CUDA) or the Mac GPU (MPS) if available
if torch.cuda.is_available():
model = model.to("cuda")
elif torch.backends.mps.is_available():
model = model.to("mps")
# Generate embeddings
# warn the user: an OOM here means the batch size needs to be turned down
embeddings = model.encode(
chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=16
)
socket.close()
context.term()
return embeddings
def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings using OpenAI API."""
try:
import openai
import os
except ImportError as e:
raise RuntimeError(
"openai not available. Install with: uv pip install openai"
) from e
# Get API key from environment
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
client = openai.OpenAI(api_key=api_key)
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
)
# OpenAI has a limit on batch size and input length
max_batch_size = 100 # Conservative batch size
all_embeddings = []
try:
from tqdm import tqdm
total_batches = (len(chunks) + max_batch_size - 1) // max_batch_size
batch_range = range(0, len(chunks), max_batch_size)
batch_iterator = tqdm(batch_range, desc="Computing embeddings", unit="batch", total=total_batches)
except ImportError:
# Fallback without progress bar
batch_iterator = range(0, len(chunks), max_batch_size)
for i in batch_iterator:
batch_chunks = chunks[i:i + max_batch_size]
try:
response = client.embeddings.create(model=model_name, input=batch_chunks)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
except Exception as e:
print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
print(
f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
)
return embeddings
def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
"""Computes embeddings using an MLX model."""
try:
import mlx.core as mx
from mlx_lm.utils import load
from tqdm import tqdm
except ImportError as e:
raise RuntimeError(
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
) from e
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
)
# Load model and tokenizer
model, tokenizer = load(model_name)
# Process chunks in batches with progress bar
all_embeddings = []
try:
from tqdm import tqdm
batch_iterator = tqdm(range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch")
except ImportError:
batch_iterator = range(0, len(chunks), batch_size)
for i in batch_iterator:
batch_chunks = chunks[i:i + batch_size]
# Tokenize all chunks in the batch
batch_token_ids = []
for chunk in batch_chunks:
token_ids = tokenizer.encode(chunk) # type: ignore
batch_token_ids.append(token_ids)
# Pad sequences to the same length for batch processing
max_length = max(len(ids) for ids in batch_token_ids)
padded_token_ids = []
for token_ids in batch_token_ids:
# Pad with tokenizer.pad_token_id or 0
padded = token_ids + [0] * (max_length - len(token_ids))
padded_token_ids.append(padded)
# Convert to MLX array with batch dimension
input_ids = mx.array(padded_token_ids)
# Get embeddings for the batch
embeddings = model(input_ids)
# Mean pooling for each sequence in the batch
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
# Convert batch embeddings to numpy
for j in range(len(batch_chunks)):
pooled_list = pooled[j].tolist() # Convert to list
pooled_numpy = np.array(pooled_list, dtype=np.float32)
all_embeddings.append(pooled_numpy)
# Stack numpy arrays
return np.stack(all_embeddings)
@dataclass
class SearchResult:
id: str
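
A minimal usage sketch of the refactored compute_embeddings dispatcher shown at the start of this file's diff. The import path and model name are assumptions, and the use_server=True branch presumes an embedding server is already listening on the given port (normally one started by the searcher's EmbeddingServerManager).

from leann.api import compute_embeddings  # assumed import path for this module

chunks = ["first passage", "second passage"]
model = "sentence-transformers/all-MiniLM-L6-v2"  # placeholder model name

# Build path: run the model in-process via embedding_compute (what build_index uses).
build_vecs = compute_embeddings(
    chunks, model, mode="sentence-transformers", use_server=False
)

# Search path: forward the texts to the embedding server over ZMQ on `port`.
query_vecs = compute_embeddings(
    ["a user query"], model, mode="sentence-transformers", use_server=True, port=5557
)

print(build_vecs.shape, query_vecs.shape)  # (n_texts, embedding_dim) numpy arrays
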
@@ -347,8 +150,6 @@ class LeannBuilder:
self.dimensions = dimensions
self.embedding_mode = embedding_mode
self.backend_kwargs = backend_kwargs
if 'mlx' in self.embedding_model:
self.embedding_mode = "mlx"
self.chunks: List[Dict[str, Any]] = []
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
@@ -380,10 +181,13 @@ class LeannBuilder:
with open(passages_file, "w", encoding="utf-8") as f:
try:
from tqdm import tqdm
chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
chunk_iterator = tqdm(
self.chunks, desc="Writing passages", unit="chunk"
)
except ImportError:
chunk_iterator = self.chunks
for chunk in chunk_iterator:
offset = f.tell()
json.dump(
@@ -401,7 +205,11 @@ class LeannBuilder:
pickle.dump(offset_map, f)
texts_to_embed = [c["text"] for c in self.chunks]
embeddings = compute_embeddings(
texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
texts_to_embed,
self.embedding_model,
self.embedding_mode,
use_server=False,
port=5557,
)
string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
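
The passage store written just above is an offset-indexed JSONL file: build_index records f.tell() before dumping each chunk and pickles the resulting id-to-offset map alongside it. A hedged sketch of reading a single passage back; the file names, the one-record-per-line framing, and the JSON field layout are assumptions based on this diff.

import json
import pickle

passages_path = "demo.passages.jsonl"   # placeholder; build_index chooses the real name
offsets_path = "demo.passages.offsets"  # placeholder for the pickled offset map

with open(offsets_path, "rb") as f:
    offset_map = pickle.load(f)  # {chunk_id: byte offset into the JSONL file}

def load_passage(chunk_id: str) -> dict:
    """Seek straight to one record instead of scanning the whole file."""
    with open(passages_path, "r", encoding="utf-8") as f:
        f.seek(offset_map[chunk_id])
        return json.loads(f.readline())
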

View File

@@ -0,0 +1,272 @@
"""
Unified embedding computation module
Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""
import numpy as np
import torch
from typing import List
import logging
logger = logging.getLogger(__name__)
def compute_embeddings(
texts: List[str], model_name: str, mode: str = "sentence-transformers"
) -> np.ndarray:
"""
Unified embedding computation entry point
Args:
texts: List of texts to compute embeddings for
model_name: Model name
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(texts, model_name)
elif mode == "openai":
return compute_embeddings_openai(texts, model_name)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
else:
raise ValueError(f"Unsupported embedding mode: {mode}")
def compute_embeddings_sentence_transformers(
texts: List[str],
model_name: str,
use_fp16: bool = True,
device: str = "auto",
batch_size: int = 32,
) -> np.ndarray:
"""
Compute embeddings using SentenceTransformer
Preserves all optimization parameters to ensure consistency with original embedding_server
Args:
texts: List of texts to compute embeddings for
model_name: SentenceTransformer model name
use_fp16: Whether to use FP16 precision
device: Device selection ('auto', 'cuda', 'mps', 'cpu')
batch_size: Batch size for processing
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
print(
f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
)
from sentence_transformers import SentenceTransformer
# Auto-detect device
if device == "auto":
if torch.cuda.is_available():
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
print(f"INFO: Using device: {device}")
# Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
model_kwargs = {
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
"low_cpu_mem_usage": True,
"_fast_init": True, # Skip weight initialization checks for faster loading
}
tokenizer_kwargs = {
"use_fast": True, # Use fast tokenizer for better runtime performance
}
# Load SentenceTransformer (try local first, then network)
print(f"INFO: Loading SentenceTransformer model: {model_name}")
try:
# Try local loading (avoid network delays)
model_kwargs["local_files_only"] = True
tokenizer_kwargs["local_files_only"] = True
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=True,
)
print("✅ Model loaded successfully! (local + optimized)")
except Exception as e:
print(f"Local loading failed ({e}), trying network download...")
# Fallback to network loading
model_kwargs["local_files_only"] = False
tokenizer_kwargs["local_files_only"] = False
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=False,
)
print("✅ Model loaded successfully! (network + optimized)")
# Apply additional optimizations (if supported)
if use_fp16 and device in ["cuda", "mps"]:
try:
model = model.half()
model = torch.compile(model)
print(f"✅ Using FP16 precision and compile optimization: {model_name}")
except Exception as e:
print(
f"FP16 or compile optimization failed, continuing with default settings: {e}"
)
# Compute embeddings (using SentenceTransformer's optimized implementation)
print("INFO: Starting embedding computation...")
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=False, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False, # Keep consistent with original API behavior
device=device,
)
print(
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
# Validate results
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
raise RuntimeError(
f"Detected NaN or Inf values in embeddings, model: {model_name}"
)
return embeddings
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
"""Compute embeddings using OpenAI API"""
try:
import openai
import os
except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}")
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
client = openai.OpenAI(api_key=api_key)
print(
f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
)
# OpenAI has limits on batch size and input length
max_batch_size = 100 # Conservative batch size
all_embeddings = []
try:
from tqdm import tqdm
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
batch_range = range(0, len(texts), max_batch_size)
batch_iterator = tqdm(
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
)
except ImportError:
# Fallback when tqdm is not available
batch_iterator = range(0, len(texts), max_batch_size)
for i in batch_iterator:
batch_texts = texts[i : i + max_batch_size]
try:
response = client.embeddings.create(model=model_name, input=batch_texts)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
except Exception as e:
print(f"ERROR: Batch {i} failed: {e}")
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
print(
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
return embeddings
def compute_embeddings_mlx(
chunks: List[str], model_name: str, batch_size: int = 16
) -> np.ndarray:
"""Computes embeddings using an MLX model."""
try:
import mlx.core as mx
from mlx_lm.utils import load
from tqdm import tqdm
except ImportError as e:
raise RuntimeError(
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
) from e
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
)
# Load model and tokenizer
model, tokenizer = load(model_name)
# Process chunks in batches with progress bar
all_embeddings = []
try:
from tqdm import tqdm
batch_iterator = tqdm(
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
)
except ImportError:
batch_iterator = range(0, len(chunks), batch_size)
for i in batch_iterator:
batch_chunks = chunks[i : i + batch_size]
# Tokenize all chunks in the batch
batch_token_ids = []
for chunk in batch_chunks:
token_ids = tokenizer.encode(chunk) # type: ignore
batch_token_ids.append(token_ids)
# Pad sequences to the same length for batch processing
max_length = max(len(ids) for ids in batch_token_ids)
padded_token_ids = []
for token_ids in batch_token_ids:
# Pad with tokenizer.pad_token_id or 0
padded = token_ids + [0] * (max_length - len(token_ids))
padded_token_ids.append(padded)
# Convert to MLX array with batch dimension
input_ids = mx.array(padded_token_ids)
# Get embeddings for the batch
embeddings = model(input_ids)
# Mean pooling for each sequence in the batch
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
# Convert batch embeddings to numpy
for j in range(len(batch_chunks)):
pooled_list = pooled[j].tolist() # Convert to list
pooled_numpy = np.array(pooled_list, dtype=np.float32)
all_embeddings.append(pooled_numpy)
# Stack numpy arrays
return np.stack(all_embeddings)
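
For reference, a hedged example of calling the lower-level SentenceTransformer path in this module directly, exercising the optimization knobs it preserves; the import path and model name are placeholders.

from leann.embedding_compute import compute_embeddings_sentence_transformers  # assumed path

texts = ["passage one", "passage two"]
vecs = compute_embeddings_sentence_transformers(
    texts,
    "sentence-transformers/all-MiniLM-L6-v2",  # placeholder model name
    use_fp16=True,    # pairs with the half()/torch.compile attempt above
    device="auto",    # resolved to cuda / mps / cpu exactly as in this module
    batch_size=32,
)
print(vecs.shape)  # (len(texts), embedding_dim); embeddings are not normalized
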

View File

@@ -4,10 +4,8 @@ import atexit
import socket
import subprocess
import sys
import zmq
import msgpack
from pathlib import Path
from typing import Optional, Dict
from typing import Optional
import select
import psutil
@@ -19,7 +17,7 @@ def _check_port(port: int) -> bool:
def _check_process_matches_config(
port: int, expected_model: str, expected_passages_file: str = None
port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""
Check if the process using the port matches our expected model and passages file.
@@ -29,16 +27,18 @@ def _check_process_matches_config(
for proc in psutil.process_iter(["pid", "cmdline"]):
if not _is_process_listening_on_port(proc, port):
continue
cmdline = proc.info["cmdline"]
if not cmdline:
continue
return _check_cmdline_matches_config(cmdline, port, expected_model, expected_passages_file)
return _check_cmdline_matches_config(
cmdline, port, expected_model, expected_passages_file
)
print(f"DEBUG: No process found listening on port {port}")
return False
except Exception as e:
print(f"WARNING: Could not check process on port {port}: {e}")
return False
@@ -57,31 +57,36 @@ def _is_process_listening_on_port(proc, port: int) -> bool:
def _check_cmdline_matches_config(
cmdline: list, port: int, expected_model: str, expected_passages_file: str = None
cmdline: list, port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""Check if command line matches our expected configuration."""
cmdline_str = " ".join(cmdline)
print(f"DEBUG: Found process on port {port}: {cmdline_str}")
# Check if it's our embedding server
is_embedding_server = any(server_type in cmdline_str for server_type in [
"embedding_server",
"leann_backend_diskann.embedding_server",
"leann_backend_hnsw.hnsw_embedding_server"
])
is_embedding_server = any(
server_type in cmdline_str
for server_type in [
"embedding_server",
"leann_backend_diskann.embedding_server",
"leann_backend_hnsw.hnsw_embedding_server",
]
)
if not is_embedding_server:
print(f"DEBUG: Process on port {port} is not our embedding server")
return False
# Check model name
model_matches = _check_model_in_cmdline(cmdline, expected_model)
# Check passages file if provided
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
result = model_matches and passages_matches
print(f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}")
print(
f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
)
return result
@@ -89,27 +94,24 @@ def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
"""Check if the command line contains the expected model."""
if "--model-name" not in cmdline:
return False
model_idx = cmdline.index("--model-name")
if model_idx + 1 >= len(cmdline):
return False
actual_model = cmdline[model_idx + 1]
return actual_model == expected_model
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str = None) -> bool:
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
"""Check if the command line contains the expected passages file."""
if not expected_passages_file:
return True # No passages file expected
if "--passages-file" not in cmdline:
return False # Expected but not found
passages_idx = cmdline.index("--passages-file")
if passages_idx + 1 >= len(cmdline):
return False
actual_passages = cmdline[passages_idx + 1]
expected_path = Path(expected_passages_file).resolve()
actual_path = Path(actual_passages).resolve()
@@ -117,7 +119,7 @@ def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str = None
def _find_compatible_port_or_next_available(
start_port: int, model_name: str, passages_file: str = None, max_attempts: int = 100
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
) -> tuple[int, bool]:
"""
Find a port that either has a compatible server or is available.
@@ -177,9 +179,13 @@ class EmbeddingServerManager:
tuple[bool, int]: (success, actual_port_used)
"""
passages_file = kwargs.get("passages_file")
assert isinstance(passages_file, str), "passages_file must be a string"
# Check if we have a compatible running server
if self._has_compatible_running_server(model_name, passages_file):
assert self.server_port is not None, (
"a compatible running server should set server_port"
)
return True, self.server_port
# Find available port (compatible or free)
@@ -203,25 +209,34 @@ class EmbeddingServerManager:
# Start new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
def _has_compatible_running_server(
self, model_name: str, passages_file: str
) -> bool:
"""Check if we have a compatible running server."""
if not (self.server_process and self.server_process.poll() is None and self.server_port):
if not (
self.server_process
and self.server_process.poll() is None
and self.server_port
):
return False
if _check_process_matches_config(self.server_port, model_name, passages_file):
print(f"✅ Existing server process (PID {self.server_process.pid}) is compatible")
print(
f"✅ Existing server process (PID {self.server_process.pid}) is compatible"
)
return True
print("⚠️ Existing server process is incompatible. Stopping it...")
self.stop_server()
print("⚠️ Existing server process is incompatible. Should start a new server.")
return False
def _start_new_server(self, port: int, model_name: str, embedding_mode: str, **kwargs) -> tuple[bool, int]:
def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> tuple[bool, int]:
"""Start a new embedding server on the given port."""
print(f"INFO: Starting embedding server on port {port}...")
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
try:
self._launch_server_process(command, port)
return self._wait_for_server_ready(port)
@@ -229,20 +244,24 @@ class EmbeddingServerManager:
print(f"❌ ERROR: Failed to start embedding server: {e}")
return False, port
def _build_server_command(self, port: int, model_name: str, embedding_mode: str, **kwargs) -> list:
def _build_server_command(
self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> list:
"""Build the command to start the embedding server."""
command = [
sys.executable, "-m", self.backend_module_name,
"--zmq-port", str(port),
"--model-name", model_name,
sys.executable,
"-m",
self.backend_module_name,
"--zmq-port",
str(port),
"--model-name",
model_name,
]
if kwargs.get("passages_file"):
command.extend(["--passages-file", str(kwargs["passages_file"])])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
if kwargs.get("enable_warmup") is False:
command.extend(["--disable-warmup"])
return command
@@ -252,13 +271,18 @@ class EmbeddingServerManager:
print(f"INFO: Command: {' '.join(command)}")
self.server_process = subprocess.Popen(
command, cwd=project_root,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
text=True, encoding="utf-8", bufsize=1, universal_newlines=True,
command,
cwd=project_root,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
encoding="utf-8",
bufsize=1,
universal_newlines=True,
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process
if not self._atexit_registered:
# Use a lambda to avoid issues with bound methods
@@ -273,12 +297,12 @@ class EmbeddingServerManager:
print("✅ Embedding server is ready!")
threading.Thread(target=self._log_monitor, daemon=True).start()
return True, port
if self.server_process.poll() is not None:
print("❌ ERROR: Server terminated during startup.")
self._print_recent_output()
return False, port
time.sleep(wait_interval)
print(f"❌ ERROR: Server failed to start within {max_wait} seconds.")
@@ -317,20 +341,24 @@ class EmbeddingServerManager:
"""Stops the embedding server process if it's running."""
if not self.server_process:
return
if self.server_process.poll() is not None:
# Process already terminated
self.server_process = None
return
print(f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}...")
print(
f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
)
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print(f"INFO: Server process {self.server_process.pid} terminated.")
except subprocess.TimeoutExpired:
print(f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it.")
print(
f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it."
)
self.server_process.kill()
self.server_process = None
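
A minimal end-to-end sketch of driving the manager after these changes and talking to the server it launches. The backend module name matches one accepted by _check_cmdline_matches_config above; the passages file path and model name are placeholders, and the wire format (msgpack-encoded list of texts in, list of vectors out) mirrors compute_embeddings_via_server in the first file of this diff.

import msgpack
import numpy as np
import zmq

from leann.embedding_server_manager import EmbeddingServerManager  # assumed import path

manager = EmbeddingServerManager(
    backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
ok, port = manager.start_server(
    port=5557,
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # placeholder model
    embedding_mode="sentence-transformers",
    passages_file="/tmp/demo.passages.jsonl",  # placeholder; must be a str after this change
    enable_warmup=False,
)
if not ok:
    raise RuntimeError(f"embedding server failed to start on port {port}")

ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)
sock.connect(f"tcp://localhost:{port}")  # use the port actually returned, not the requested one
sock.send(msgpack.packb(["a query to embed"]))
vectors = np.array(msgpack.unpackb(sock.recv()), dtype=np.float32)
sock.close()
ctx.term()

manager.stop_server()
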

View File

@@ -43,7 +43,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
"WARNING: embedding_model not found in meta.json. Recompute will fail."
)
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name
)
@@ -57,10 +56,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
with open(meta_path, "r", encoding="utf-8") as f:
return json.load(f)
def _ensure_server_running(
self, passages_source_file: str, port: int, **kwargs
) -> None:
) -> int:
"""
Ensures the embedding server is running if recompute is needed.
This is a helper for subclasses.
@@ -71,7 +69,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
)
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
server_started, actual_port = self.embedding_server_manager.start_server(
port=port,
model_name=self.embedding_model,
@@ -81,11 +79,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
enable_warmup=kwargs.get("enable_warmup", False),
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
# Update the port information for future use
if hasattr(self, '_actual_server_port'):
self._actual_server_port = actual_port
raise RuntimeError(
f"Failed to start embedding server on port {actual_port}"
)
return actual_port
def compute_query_embedding(
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
@@ -105,9 +103,13 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
if use_server_if_available:
try:
# Ensure we have a server with passages_file for compatibility
passages_source_file = self.index_dir / f"{self.index_path.name}.meta.json"
self._ensure_server_running(str(passages_source_file), zmq_port)
passages_source_file = (
self.index_dir / f"{self.index_path.name}.meta.json"
)
zmq_port = self._ensure_server_running(
str(passages_source_file), zmq_port
)
return self._compute_embedding_via_server([query], zmq_port)[
0:1
] # Return (1, D) shape
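
The net effect of this last change is that _ensure_server_running now reports the port the server actually bound, which can differ from the requested one when a compatible server is found on another port. A hedged sketch of the calling pattern, using only names that appear in this diff:

# Sketch of the query path inside a BaseSearcher subclass (assumes the attributes shown above).
def embed_query(self, query: str, requested_port: int = 5557):
    meta_file = self.index_dir / f"{self.index_path.name}.meta.json"
    # The server may come up on a different port, so always use the returned value.
    actual_port = self._ensure_server_running(str(meta_file), requested_port)
    return self._compute_embedding_via_server([query], actual_port)[0:1]  # shape (1, D)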