fix: same embedding logic

Andy Lee
2025-07-21 20:12:40 -07:00
parent f47f76d6d7
commit 54155e8b10
5 changed files with 558 additions and 1333 deletions


File diff suppressed because it is too large


@@ -9,9 +9,6 @@ import numpy as np
 from pathlib import Path
 from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
-import uuid
-import torch
 from .registry import BACKEND_REGISTRY
 from .interface import LeannBackendFactoryInterface
 from .chat import get_llm
@@ -22,7 +19,7 @@ def compute_embeddings(
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
-    use_mlx: bool = False,  # Backward compatibility: if True, override mode to 'mlx'
+    port: int = 5557,
 ) -> np.ndarray:
     """
     Computes embeddings using different backends.
@@ -39,254 +36,60 @@ def compute_embeddings(
     Returns:
         numpy array of embeddings
     """
-    # Override mode for backward compatibility
-    if use_mlx:
-        mode = "mlx"
-    # Auto-detect mode based on model name if not explicitly set
-    if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
-        mode = "openai"
-    if mode == "mlx":
-        return compute_embeddings_mlx(chunks, model_name, batch_size=16)
-    elif mode == "openai":
-        return compute_embeddings_openai(chunks, model_name)
-    elif mode == "sentence-transformers":
-        return compute_embeddings_sentence_transformers(
-            chunks, model_name, use_server=use_server
-        )
+    if use_server:
+        # Use embedding server (for search/query)
+        return compute_embeddings_via_server(chunks, model_name, port=port)
     else:
-        raise ValueError(
-            f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
+        # Use direct computation (for build_index)
+        from .embedding_compute import (
+            compute_embeddings as compute_embeddings_direct,
+        )
+        return compute_embeddings_direct(
+            chunks,
+            model_name,
+            mode=mode,
         )
-def compute_embeddings_sentence_transformers(
-    chunks: List[str], model_name: str, use_server: bool = True
+def compute_embeddings_via_server(
+    chunks: List[str], model_name: str, port: int
 ) -> np.ndarray:
     """Computes embeddings using sentence-transformers.

     Args:
         chunks: List of text chunks to embed
         model_name: Name of the sentence transformer model
-        use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
     """
-    if not use_server:
-        print(
-            f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
-        )
-        return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
     print(
         f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
     )
-    # Use embedding server for sentence-transformers too
-    # This avoids loading the model twice (once in API, once in server)
-    try:
-        # Import ZMQ client functionality and server manager
-        import zmq
-        import msgpack
-        import numpy as np
-        from .embedding_server_manager import EmbeddingServerManager
-        # Ensure embedding server is running
-        port = 5557
-        server_manager = EmbeddingServerManager(
-            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
-        )
-        server_started, actual_port = server_manager.start_server(
-            port=port,
-            model_name=model_name,
-            embedding_mode="sentence-transformers",
-            enable_warmup=False,
-        )
-        if not server_started:
-            raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
-        # Use the actual port for connection
-        port = actual_port
-        # Connect to embedding server
-        context = zmq.Context()
-        socket = context.socket(zmq.REQ)
-        socket.connect(f"tcp://localhost:{port}")
-        # Send chunks to server for embedding computation
-        request = chunks
-        socket.send(msgpack.packb(request))
-        # Receive embeddings from server
-        response = socket.recv()
-        embeddings_list = msgpack.unpackb(response)
-        # Convert back to numpy array
-        embeddings = np.array(embeddings_list, dtype=np.float32)
-        socket.close()
-        context.term()
-        return embeddings
-    except Exception as e:
-        # Fallback to direct sentence-transformers if server connection fails
-        print(
-            f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
-        )
-        return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
-def _compute_embeddings_sentence_transformers_direct(
-    chunks: List[str], model_name: str
-) -> np.ndarray:
-    """Direct sentence-transformers computation (fallback)."""
-    try:
-        from sentence_transformers import SentenceTransformer
-    except ImportError as e:
-        raise RuntimeError(
-            "sentence-transformers not available. Install with: uv pip install sentence-transformers"
-        ) from e
-    # Load model using sentence-transformers
-    model = SentenceTransformer(model_name)
-    model = model.half()
-    print(
-        f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
-    )
-    # use acclerater GPU or MAC GPU
-    if torch.cuda.is_available():
-        model = model.to("cuda")
-    elif torch.backends.mps.is_available():
-        model = model.to("mps")
-    # Generate embeddings
-    # give use an warning if OOM here means we need to turn down the batch size
-    embeddings = model.encode(
-        chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=16
-    )
+    import zmq
+    import msgpack
+    import numpy as np
+
+    # Connect to embedding server
+    context = zmq.Context()
+    socket = context.socket(zmq.REQ)
+    socket.connect(f"tcp://localhost:{port}")
+
+    # Send chunks to server for embedding computation
+    request = chunks
+    socket.send(msgpack.packb(request))
+
+    # Receive embeddings from server
+    response = socket.recv()
+    embeddings_list = msgpack.unpackb(response)
+
+    # Convert back to numpy array
+    embeddings = np.array(embeddings_list, dtype=np.float32)
+
+    socket.close()
+    context.term()
+
     return embeddings
-def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
-    """Computes embeddings using OpenAI API."""
-    try:
-        import openai
-        import os
-    except ImportError as e:
-        raise RuntimeError(
-            "openai not available. Install with: uv pip install openai"
-        ) from e
-    # Get API key from environment
-    api_key = os.getenv("OPENAI_API_KEY")
-    if not api_key:
-        raise RuntimeError("OPENAI_API_KEY environment variable not set")
-    client = openai.OpenAI(api_key=api_key)
-    print(
-        f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
-    )
-    # OpenAI has a limit on batch size and input length
-    max_batch_size = 100  # Conservative batch size
-    all_embeddings = []
-    try:
-        from tqdm import tqdm
-        total_batches = (len(chunks) + max_batch_size - 1) // max_batch_size
-        batch_range = range(0, len(chunks), max_batch_size)
-        batch_iterator = tqdm(batch_range, desc="Computing embeddings", unit="batch", total=total_batches)
-    except ImportError:
-        # Fallback without progress bar
-        batch_iterator = range(0, len(chunks), max_batch_size)
-    for i in batch_iterator:
-        batch_chunks = chunks[i:i + max_batch_size]
-        try:
-            response = client.embeddings.create(model=model_name, input=batch_chunks)
-            batch_embeddings = [embedding.embedding for embedding in response.data]
-            all_embeddings.extend(batch_embeddings)
-        except Exception as e:
-            print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
-            raise
-    embeddings = np.array(all_embeddings, dtype=np.float32)
-    print(
-        f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
-    )
-    return embeddings
-def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
-    """Computes embeddings using an MLX model."""
-    try:
-        import mlx.core as mx
-        from mlx_lm.utils import load
-        from tqdm import tqdm
-    except ImportError as e:
-        raise RuntimeError(
-            "MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
-        ) from e
-    print(
-        f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
-    )
-    # Load model and tokenizer
-    model, tokenizer = load(model_name)
-    # Process chunks in batches with progress bar
-    all_embeddings = []
-    try:
-        from tqdm import tqdm
-        batch_iterator = tqdm(range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch")
-    except ImportError:
-        batch_iterator = range(0, len(chunks), batch_size)
-    for i in batch_iterator:
-        batch_chunks = chunks[i:i + batch_size]
-        # Tokenize all chunks in the batch
-        batch_token_ids = []
-        for chunk in batch_chunks:
-            token_ids = tokenizer.encode(chunk)  # type: ignore
-            batch_token_ids.append(token_ids)
-        # Pad sequences to the same length for batch processing
-        max_length = max(len(ids) for ids in batch_token_ids)
-        padded_token_ids = []
-        for token_ids in batch_token_ids:
-            # Pad with tokenizer.pad_token_id or 0
-            padded = token_ids + [0] * (max_length - len(token_ids))
-            padded_token_ids.append(padded)
-        # Convert to MLX array with batch dimension
-        input_ids = mx.array(padded_token_ids)
-        # Get embeddings for the batch
-        embeddings = model(input_ids)
-        # Mean pooling for each sequence in the batch
-        pooled = embeddings.mean(axis=1)  # Shape: (batch_size, hidden_size)
-        # Convert batch embeddings to numpy
-        for j in range(len(batch_chunks)):
-            pooled_list = pooled[j].tolist()  # Convert to list
-            pooled_numpy = np.array(pooled_list, dtype=np.float32)
-            all_embeddings.append(pooled_numpy)
-    # Stack numpy arrays
-    return np.stack(all_embeddings)
 @dataclass
 class SearchResult:
     id: str
@@ -347,8 +150,6 @@ class LeannBuilder:
         self.dimensions = dimensions
         self.embedding_mode = embedding_mode
         self.backend_kwargs = backend_kwargs
-        if 'mlx' in self.embedding_model:
-            self.embedding_mode = "mlx"
         self.chunks: List[Dict[str, Any]] = []

     def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
@@ -380,7 +181,10 @@ class LeannBuilder:
         with open(passages_file, "w", encoding="utf-8") as f:
             try:
                 from tqdm import tqdm
-                chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
+                chunk_iterator = tqdm(
+                    self.chunks, desc="Writing passages", unit="chunk"
+                )
             except ImportError:
                 chunk_iterator = self.chunks
@@ -401,7 +205,11 @@ class LeannBuilder:
             pickle.dump(offset_map, f)
         texts_to_embed = [c["text"] for c in self.chunks]
         embeddings = compute_embeddings(
-            texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
+            texts_to_embed,
+            self.embedding_model,
+            self.embedding_mode,
+            use_server=False,
+            port=5557,
         )
         string_ids = [chunk["id"] for chunk in self.chunks]
         current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
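The new compute_embeddings_via_server above speaks a very small wire protocol to an already-running embedding server: a msgpack-encoded list of strings goes out on a ZMQ REQ socket, and a msgpack-encoded list of float vectors comes back. The sketch below is a minimal stand-in server for that protocol; it assumes pyzmq, msgpack, and sentence-transformers are installed, and the model name is only an example. The real servers are the backend modules (e.g. leann_backend_hnsw.hnsw_embedding_server), so this illustrates the request/response format, not the project's actual server.

# Minimal REP-socket server speaking the same msgpack protocol the client uses.
# Assumptions: pyzmq, msgpack, sentence-transformers installed; example model name.
import msgpack
import numpy as np
import zmq
from sentence_transformers import SentenceTransformer


def serve(port: int = 5557, model_name: str = "all-MiniLM-L6-v2") -> None:
    model = SentenceTransformer(model_name)
    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind(f"tcp://*:{port}")
    while True:
        chunks = msgpack.unpackb(socket.recv())  # list[str] from the REQ client
        embeddings = model.encode(chunks, convert_to_numpy=True)
        # Reply with plain lists so the client can rebuild a float32 array
        socket.send(msgpack.packb(embeddings.astype(np.float32).tolist()))


if __name__ == "__main__":
    serve()

Because the request is just the chunk list and the reply is just the vector list, any process that binds a REP socket on the agreed port and answers in this shape will satisfy compute_embeddings_via_server.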


@@ -0,0 +1,272 @@
"""
Unified embedding computation module
Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""
import numpy as np
import torch
from typing import List
import logging
logger = logging.getLogger(__name__)
def compute_embeddings(
texts: List[str], model_name: str, mode: str = "sentence-transformers"
) -> np.ndarray:
"""
Unified embedding computation entry point
Args:
texts: List of texts to compute embeddings for
model_name: Model name
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(texts, model_name)
elif mode == "openai":
return compute_embeddings_openai(texts, model_name)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
else:
raise ValueError(f"Unsupported embedding mode: {mode}")
def compute_embeddings_sentence_transformers(
texts: List[str],
model_name: str,
use_fp16: bool = True,
device: str = "auto",
batch_size: int = 32,
) -> np.ndarray:
"""
Compute embeddings using SentenceTransformer
Preserves all optimization parameters to ensure consistency with original embedding_server
Args:
texts: List of texts to compute embeddings for
model_name: SentenceTransformer model name
use_fp16: Whether to use FP16 precision
device: Device selection ('auto', 'cuda', 'mps', 'cpu')
batch_size: Batch size for processing
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
print(
f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
)
from sentence_transformers import SentenceTransformer
# Auto-detect device
if device == "auto":
if torch.cuda.is_available():
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
print(f"INFO: Using device: {device}")
# Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
model_kwargs = {
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
"low_cpu_mem_usage": True,
"_fast_init": True, # Skip weight initialization checks for faster loading
}
tokenizer_kwargs = {
"use_fast": True, # Use fast tokenizer for better runtime performance
}
# Load SentenceTransformer (try local first, then network)
print(f"INFO: Loading SentenceTransformer model: {model_name}")
try:
# Try local loading (avoid network delays)
model_kwargs["local_files_only"] = True
tokenizer_kwargs["local_files_only"] = True
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=True,
)
print("✅ Model loaded successfully! (local + optimized)")
except Exception as e:
print(f"Local loading failed ({e}), trying network download...")
# Fallback to network loading
model_kwargs["local_files_only"] = False
tokenizer_kwargs["local_files_only"] = False
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=False,
)
print("✅ Model loaded successfully! (network + optimized)")
# Apply additional optimizations (if supported)
if use_fp16 and device in ["cuda", "mps"]:
try:
model = model.half()
model = torch.compile(model)
print(f"✅ Using FP16 precision and compile optimization: {model_name}")
except Exception as e:
print(
f"FP16 or compile optimization failed, continuing with default settings: {e}"
)
# Compute embeddings (using SentenceTransformer's optimized implementation)
print("INFO: Starting embedding computation...")
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=False, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False, # Keep consistent with original API behavior
device=device,
)
print(
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
# Validate results
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
raise RuntimeError(
f"Detected NaN or Inf values in embeddings, model: {model_name}"
)
return embeddings
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
"""Compute embeddings using OpenAI API"""
try:
import openai
import os
except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}")
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
client = openai.OpenAI(api_key=api_key)
print(
f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
)
# OpenAI has limits on batch size and input length
max_batch_size = 100 # Conservative batch size
all_embeddings = []
try:
from tqdm import tqdm
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
batch_range = range(0, len(texts), max_batch_size)
batch_iterator = tqdm(
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
)
except ImportError:
# Fallback when tqdm is not available
batch_iterator = range(0, len(texts), max_batch_size)
for i in batch_iterator:
batch_texts = texts[i : i + max_batch_size]
try:
response = client.embeddings.create(model=model_name, input=batch_texts)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
except Exception as e:
print(f"ERROR: Batch {i} failed: {e}")
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
print(
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
return embeddings
def compute_embeddings_mlx(
chunks: List[str], model_name: str, batch_size: int = 16
) -> np.ndarray:
"""Computes embeddings using an MLX model."""
try:
import mlx.core as mx
from mlx_lm.utils import load
from tqdm import tqdm
except ImportError as e:
raise RuntimeError(
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
) from e
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
)
# Load model and tokenizer
model, tokenizer = load(model_name)
# Process chunks in batches with progress bar
all_embeddings = []
try:
from tqdm import tqdm
batch_iterator = tqdm(
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
)
except ImportError:
batch_iterator = range(0, len(chunks), batch_size)
for i in batch_iterator:
batch_chunks = chunks[i : i + batch_size]
# Tokenize all chunks in the batch
batch_token_ids = []
for chunk in batch_chunks:
token_ids = tokenizer.encode(chunk) # type: ignore
batch_token_ids.append(token_ids)
# Pad sequences to the same length for batch processing
max_length = max(len(ids) for ids in batch_token_ids)
padded_token_ids = []
for token_ids in batch_token_ids:
# Pad with tokenizer.pad_token_id or 0
padded = token_ids + [0] * (max_length - len(token_ids))
padded_token_ids.append(padded)
# Convert to MLX array with batch dimension
input_ids = mx.array(padded_token_ids)
# Get embeddings for the batch
embeddings = model(input_ids)
# Mean pooling for each sequence in the batch
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
# Convert batch embeddings to numpy
for j in range(len(batch_chunks)):
pooled_list = pooled[j].tolist() # Convert to list
pooled_numpy = np.array(pooled_list, dtype=np.float32)
all_embeddings.append(pooled_numpy)
# Stack numpy arrays
return np.stack(all_embeddings)
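For reference, a hedged smoke test of the new direct path. The diff only shows the relative import from .embedding_compute, so the absolute package path and the model name below are assumptions; only the call shape is taken from the signatures above.

# Assumed import path; adjust to wherever embedding_compute.py lives in the package.
from leann.embedding_compute import compute_embeddings

texts = ["LEANN stores passages on disk.", "Embeddings are computed at build time."]
vectors = compute_embeddings(texts, "sentence-transformers/all-MiniLM-L6-v2")

# Shape is (len(texts), embedding_dim); vectors are left unnormalized,
# matching normalize_embeddings=False in the module above.
print(vectors.shape, vectors.dtype)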


@@ -4,10 +4,8 @@ import atexit
 import socket
 import subprocess
 import sys
-import zmq
-import msgpack
 from pathlib import Path
-from typing import Optional, Dict
+from typing import Optional
 import select
 import psutil
@@ -19,7 +17,7 @@ def _check_port(port: int) -> bool:
 def _check_process_matches_config(
-    port: int, expected_model: str, expected_passages_file: str = None
+    port: int, expected_model: str, expected_passages_file: str
 ) -> bool:
     """
     Check if the process using the port matches our expected model and passages file.
@@ -34,7 +32,9 @@ def _check_process_matches_config(
             if not cmdline:
                 continue
-            return _check_cmdline_matches_config(cmdline, port, expected_model, expected_passages_file)
+            return _check_cmdline_matches_config(
+                cmdline, port, expected_model, expected_passages_file
+            )
     print(f"DEBUG: No process found listening on port {port}")
     return False
@@ -57,18 +57,21 @@ def _is_process_listening_on_port(proc, port: int) -> bool:
 def _check_cmdline_matches_config(
-    cmdline: list, port: int, expected_model: str, expected_passages_file: str = None
+    cmdline: list, port: int, expected_model: str, expected_passages_file: str
 ) -> bool:
     """Check if command line matches our expected configuration."""
     cmdline_str = " ".join(cmdline)
     print(f"DEBUG: Found process on port {port}: {cmdline_str}")

     # Check if it's our embedding server
-    is_embedding_server = any(server_type in cmdline_str for server_type in [
-        "embedding_server",
-        "leann_backend_diskann.embedding_server",
-        "leann_backend_hnsw.hnsw_embedding_server"
-    ])
+    is_embedding_server = any(
+        server_type in cmdline_str
+        for server_type in [
+            "embedding_server",
+            "leann_backend_diskann.embedding_server",
+            "leann_backend_hnsw.hnsw_embedding_server",
+        ]
+    )

     if not is_embedding_server:
         print(f"DEBUG: Process on port {port} is not our embedding server")
@@ -81,7 +84,9 @@ def _check_cmdline_matches_config(
     passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)

     result = model_matches and passages_matches
-    print(f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}")
+    print(
+        f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
+    )
     return result
@@ -98,11 +103,8 @@ def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
     return actual_model == expected_model


-def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str = None) -> bool:
+def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
     """Check if the command line contains the expected passages file."""
-    if not expected_passages_file:
-        return True  # No passages file expected
     if "--passages-file" not in cmdline:
         return False  # Expected but not found
@@ -117,7 +119,7 @@ def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str = None
 def _find_compatible_port_or_next_available(
-    start_port: int, model_name: str, passages_file: str = None, max_attempts: int = 100
+    start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
 ) -> tuple[int, bool]:
     """
     Find a port that either has a compatible server or is available.
@@ -177,9 +179,13 @@ class EmbeddingServerManager:
             tuple[bool, int]: (success, actual_port_used)
         """
         passages_file = kwargs.get("passages_file")
+        assert isinstance(passages_file, str), "passages_file must be a string"

         # Check if we have a compatible running server
         if self._has_compatible_running_server(model_name, passages_file):
+            assert self.server_port is not None, (
+                "a compatible running server should set server_port"
+            )
             return True, self.server_port

         # Find available port (compatible or free)
@@ -203,20 +209,29 @@ class EmbeddingServerManager:
         # Start new server
         return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)

-    def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
+    def _has_compatible_running_server(
+        self, model_name: str, passages_file: str
+    ) -> bool:
         """Check if we have a compatible running server."""
-        if not (self.server_process and self.server_process.poll() is None and self.server_port):
+        if not (
+            self.server_process
+            and self.server_process.poll() is None
+            and self.server_port
+        ):
             return False

         if _check_process_matches_config(self.server_port, model_name, passages_file):
-            print(f"✅ Existing server process (PID {self.server_process.pid}) is compatible")
+            print(
+                f"✅ Existing server process (PID {self.server_process.pid}) is compatible"
+            )
             return True

-        print("⚠️ Existing server process is incompatible. Stopping it...")
-        self.stop_server()
+        print("⚠️ Existing server process is incompatible. Should start a new server.")
         return False

-    def _start_new_server(self, port: int, model_name: str, embedding_mode: str, **kwargs) -> tuple[bool, int]:
+    def _start_new_server(
+        self, port: int, model_name: str, embedding_mode: str, **kwargs
+    ) -> tuple[bool, int]:
         """Start a new embedding server on the given port."""
         print(f"INFO: Starting embedding server on port {port}...")
@@ -229,20 +244,24 @@ class EmbeddingServerManager:
print(f"❌ ERROR: Failed to start embedding server: {e}") print(f"❌ ERROR: Failed to start embedding server: {e}")
return False, port return False, port
def _build_server_command(self, port: int, model_name: str, embedding_mode: str, **kwargs) -> list: def _build_server_command(
self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> list:
"""Build the command to start the embedding server.""" """Build the command to start the embedding server."""
command = [ command = [
sys.executable, "-m", self.backend_module_name, sys.executable,
"--zmq-port", str(port), "-m",
"--model-name", model_name, self.backend_module_name,
"--zmq-port",
str(port),
"--model-name",
model_name,
] ]
if kwargs.get("passages_file"): if kwargs.get("passages_file"):
command.extend(["--passages-file", str(kwargs["passages_file"])]) command.extend(["--passages-file", str(kwargs["passages_file"])])
if embedding_mode != "sentence-transformers": if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode]) command.extend(["--embedding-mode", embedding_mode])
if kwargs.get("enable_warmup") is False:
command.extend(["--disable-warmup"])
return command return command
@@ -252,9 +271,14 @@ class EmbeddingServerManager:
print(f"INFO: Command: {' '.join(command)}") print(f"INFO: Command: {' '.join(command)}")
self.server_process = subprocess.Popen( self.server_process = subprocess.Popen(
command, cwd=project_root, command,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=project_root,
text=True, encoding="utf-8", bufsize=1, universal_newlines=True, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
encoding="utf-8",
bufsize=1,
universal_newlines=True,
) )
self.server_port = port self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}") print(f"INFO: Server process started with PID: {self.server_process.pid}")
@@ -323,14 +347,18 @@ class EmbeddingServerManager:
             self.server_process = None
             return

-        print(f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}...")
+        print(
+            f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
+        )
         self.server_process.terminate()
         try:
             self.server_process.wait(timeout=5)
             print(f"INFO: Server process {self.server_process.pid} terminated.")
         except subprocess.TimeoutExpired:
-            print(f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it.")
+            print(
+                f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it."
+            )
             self.server_process.kill()
         self.server_process = None
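Taken together, the manager changes make passages_file a required string, reuse an existing compatible server instead of killing an incompatible one, and drop the --disable-warmup flag from the launch command. Below is a hedged sketch of how a caller drives the manager after this change; the import path and passages path are assumptions, while the backend module name is one the config check above already recognizes.

# Illustrative only: the import path and the passages path are assumptions.
from leann.embedding_server_manager import EmbeddingServerManager

manager = EmbeddingServerManager(
    backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
ok, actual_port = manager.start_server(
    port=5557,
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    embedding_mode="sentence-transformers",
    passages_file="/tmp/example_index.meta.json",  # must now be a str (see the new assert)
)

# With these arguments, _build_server_command produces roughly:
#   [sys.executable, "-m", "leann_backend_hnsw.hnsw_embedding_server",
#    "--zmq-port", "5557",
#    "--model-name", "sentence-transformers/all-MiniLM-L6-v2",
#    "--passages-file", "/tmp/example_index.meta.json"]
print(ok, actual_port)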


@@ -43,7 +43,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
"WARNING: embedding_model not found in meta.json. Recompute will fail." "WARNING: embedding_model not found in meta.json. Recompute will fail."
) )
self.embedding_server_manager = EmbeddingServerManager( self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name backend_module_name=backend_module_name
) )
@@ -57,10 +56,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         with open(meta_path, "r", encoding="utf-8") as f:
             return json.load(f)

     def _ensure_server_running(
         self, passages_source_file: str, port: int, **kwargs
-    ) -> None:
+    ) -> int:
         """
         Ensures the embedding server is running if recompute is needed.
         This is a helper for subclasses.
@@ -81,11 +79,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             enable_warmup=kwargs.get("enable_warmup", False),
         )
         if not server_started:
-            raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
-
-        # Update the port information for future use
-        if hasattr(self, '_actual_server_port'):
-            self._actual_server_port = actual_port
+            raise RuntimeError(
+                f"Failed to start embedding server on port {actual_port}"
+            )
+        return actual_port

     def compute_query_embedding(
         self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
@@ -105,8 +103,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         if use_server_if_available:
             try:
                 # Ensure we have a server with passages_file for compatibility
-                passages_source_file = self.index_dir / f"{self.index_path.name}.meta.json"
-                self._ensure_server_running(str(passages_source_file), zmq_port)
+                passages_source_file = (
+                    self.index_dir / f"{self.index_path.name}.meta.json"
+                )
+                zmq_port = self._ensure_server_running(
+                    str(passages_source_file), zmq_port
+                )
                 return self._compute_embedding_via_server([query], zmq_port)[
                     0:1
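On the query side, _ensure_server_running now reports back the port it actually ended up using, and compute_query_embedding routes the ZMQ request there instead of assuming 5557. A hedged sketch of the resulting flow, using a hypothetical concrete searcher class and index path (neither name comes from this diff):

# Hypothetical subclass of BaseSearcher and index location, for illustration only.
from leann_backend_hnsw import HNSWSearcher  # assumed; use your real backend searcher

searcher = HNSWSearcher("/tmp/example_index")
query_embedding = searcher.compute_query_embedding(
    "how are embeddings computed now?",
    zmq_port=5557,  # a compatible server on another port is still found and used
)

# The [0:1] slice keeps the batch axis, so this is (1, embedding_dim).
print(query_embedding.shape)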