fix: same embedding logic
@@ -9,9 +9,6 @@ import numpy as np
 from pathlib import Path
 from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
-import uuid
-import torch
-
 from .registry import BACKEND_REGISTRY
 from .interface import LeannBackendFactoryInterface
 from .chat import get_llm
@@ -22,7 +19,7 @@ def compute_embeddings(
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
-    use_mlx: bool = False,  # Backward compatibility: if True, override mode to 'mlx'
+    port: int = 5557,
 ) -> np.ndarray:
     """
     Computes embeddings using different backends.
@@ -39,254 +36,60 @@ def compute_embeddings(
     Returns:
         numpy array of embeddings
     """
-    # Override mode for backward compatibility
-    if use_mlx:
-        mode = "mlx"
-
-    # Auto-detect mode based on model name if not explicitly set
-    if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
-        mode = "openai"
-
-    if mode == "mlx":
-        return compute_embeddings_mlx(chunks, model_name, batch_size=16)
-    elif mode == "openai":
-        return compute_embeddings_openai(chunks, model_name)
-    elif mode == "sentence-transformers":
-        return compute_embeddings_sentence_transformers(
-            chunks, model_name, use_server=use_server
-        )
+    if use_server:
+        # Use embedding server (for search/query)
+        return compute_embeddings_via_server(chunks, model_name, port=port)
     else:
-        raise ValueError(
-            f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
+        # Use direct computation (for build_index)
+        from .embedding_compute import (
+            compute_embeddings as compute_embeddings_direct,
+        )
+
+        return compute_embeddings_direct(
+            chunks,
+            model_name,
+            mode=mode,
         )
 
 
-def compute_embeddings_sentence_transformers(
-    chunks: List[str], model_name: str, use_server: bool = True
+def compute_embeddings_via_server(
+    chunks: List[str], model_name: str, port: int
 ) -> np.ndarray:
     """Computes embeddings using sentence-transformers.
 
     Args:
         chunks: List of text chunks to embed
         model_name: Name of the sentence transformer model
-        use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
     """
-    if not use_server:
-        print(
-            f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
-        )
-        return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
-
     print(
         f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
     )
+    import zmq
+    import msgpack
+    import numpy as np
 
-    # Use embedding server for sentence-transformers too
-    # This avoids loading the model twice (once in API, once in server)
-    try:
-        # Import ZMQ client functionality and server manager
-        import zmq
-        import msgpack
-        import numpy as np
-        from .embedding_server_manager import EmbeddingServerManager
-
-        # Ensure embedding server is running
-        port = 5557
-        server_manager = EmbeddingServerManager(
-            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
-        )
-
-        server_started, actual_port = server_manager.start_server(
-            port=port,
-            model_name=model_name,
-            embedding_mode="sentence-transformers",
-            enable_warmup=False,
-        )
-
-        if not server_started:
-            raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
-
-        # Use the actual port for connection
-        port = actual_port
-
-        # Connect to embedding server
-        context = zmq.Context()
-        socket = context.socket(zmq.REQ)
-        socket.connect(f"tcp://localhost:{port}")
-
-        # Send chunks to server for embedding computation
-        request = chunks
-        socket.send(msgpack.packb(request))
-
-        # Receive embeddings from server
-        response = socket.recv()
-        embeddings_list = msgpack.unpackb(response)
-
-        # Convert back to numpy array
-        embeddings = np.array(embeddings_list, dtype=np.float32)
-
-        socket.close()
-        context.term()
-
-        return embeddings
-
-    except Exception as e:
-        # Fallback to direct sentence-transformers if server connection fails
-        print(
-            f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
-        )
-        return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
-
-
-def _compute_embeddings_sentence_transformers_direct(
-    chunks: List[str], model_name: str
-) -> np.ndarray:
-    """Direct sentence-transformers computation (fallback)."""
-    try:
-        from sentence_transformers import SentenceTransformer
-    except ImportError as e:
-        raise RuntimeError(
-            "sentence-transformers not available. Install with: uv pip install sentence-transformers"
-        ) from e
-
-    # Load model using sentence-transformers
-    model = SentenceTransformer(model_name)
-
-    model = model.half()
-    print(
-        f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
-    )
-    # use accelerator GPU or Mac GPU
-
-    if torch.cuda.is_available():
-        model = model.to("cuda")
-    elif torch.backends.mps.is_available():
-        model = model.to("mps")
-
-    # Generate embeddings
-    # warn if OOM here: it means we need to turn down the batch size
-    embeddings = model.encode(
-        chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=16
-    )
-
+    # Connect to embedding server
+    context = zmq.Context()
+    socket = context.socket(zmq.REQ)
+    socket.connect(f"tcp://localhost:{port}")
+
+    # Send chunks to server for embedding computation
+    request = chunks
+    socket.send(msgpack.packb(request))
+
+    # Receive embeddings from server
+    response = socket.recv()
+    embeddings_list = msgpack.unpackb(response)
+
+    # Convert back to numpy array
+    embeddings = np.array(embeddings_list, dtype=np.float32)
+
+    socket.close()
+    context.term()
+
     return embeddings
-
-
-def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
-    """Computes embeddings using OpenAI API."""
-    try:
-        import openai
-        import os
-    except ImportError as e:
-        raise RuntimeError(
-            "openai not available. Install with: uv pip install openai"
-        ) from e
-
-    # Get API key from environment
-    api_key = os.getenv("OPENAI_API_KEY")
-    if not api_key:
-        raise RuntimeError("OPENAI_API_KEY environment variable not set")
-
-    client = openai.OpenAI(api_key=api_key)
-
-    print(
-        f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
-    )
-
-    # OpenAI has a limit on batch size and input length
-    max_batch_size = 100  # Conservative batch size
-    all_embeddings = []
-
-    try:
-        from tqdm import tqdm
-        total_batches = (len(chunks) + max_batch_size - 1) // max_batch_size
-        batch_range = range(0, len(chunks), max_batch_size)
-        batch_iterator = tqdm(batch_range, desc="Computing embeddings", unit="batch", total=total_batches)
-    except ImportError:
-        # Fallback without progress bar
-        batch_iterator = range(0, len(chunks), max_batch_size)
-
-    for i in batch_iterator:
-        batch_chunks = chunks[i:i + max_batch_size]
-
-        try:
-            response = client.embeddings.create(model=model_name, input=batch_chunks)
-            batch_embeddings = [embedding.embedding for embedding in response.data]
-            all_embeddings.extend(batch_embeddings)
-        except Exception as e:
-            print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
-            raise
-
-    embeddings = np.array(all_embeddings, dtype=np.float32)
-    print(
-        f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
-    )
-    return embeddings
-
-
-def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
-    """Computes embeddings using an MLX model."""
-    try:
-        import mlx.core as mx
-        from mlx_lm.utils import load
-        from tqdm import tqdm
-    except ImportError as e:
-        raise RuntimeError(
-            "MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
-        ) from e
-
-    print(
-        f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
-    )
-
-    # Load model and tokenizer
-    model, tokenizer = load(model_name)
-
-    # Process chunks in batches with progress bar
-    all_embeddings = []
-
-    try:
-        from tqdm import tqdm
-        batch_iterator = tqdm(range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch")
-    except ImportError:
-        batch_iterator = range(0, len(chunks), batch_size)
-
-    for i in batch_iterator:
-        batch_chunks = chunks[i:i + batch_size]
-
-        # Tokenize all chunks in the batch
-        batch_token_ids = []
-        for chunk in batch_chunks:
-            token_ids = tokenizer.encode(chunk)  # type: ignore
-            batch_token_ids.append(token_ids)
-
-        # Pad sequences to the same length for batch processing
-        max_length = max(len(ids) for ids in batch_token_ids)
-        padded_token_ids = []
-        for token_ids in batch_token_ids:
-            # Pad with tokenizer.pad_token_id or 0
-            padded = token_ids + [0] * (max_length - len(token_ids))
-            padded_token_ids.append(padded)
-
-        # Convert to MLX array with batch dimension
-        input_ids = mx.array(padded_token_ids)
-
-        # Get embeddings for the batch
-        embeddings = model(input_ids)
-
-        # Mean pooling for each sequence in the batch
-        pooled = embeddings.mean(axis=1)  # Shape: (batch_size, hidden_size)
-
-        # Convert batch embeddings to numpy
-        for j in range(len(batch_chunks)):
-            pooled_list = pooled[j].tolist()  # Convert to list
-            pooled_numpy = np.array(pooled_list, dtype=np.float32)
-            all_embeddings.append(pooled_numpy)
-
-    # Stack numpy arrays
-    return np.stack(all_embeddings)
 
 
 @dataclass
 class SearchResult:
     id: str
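The new compute_embeddings_via_server above is a thin ZMQ REQ client: it sends the chunk list as msgpack and expects a msgpack list of embedding vectors back. The matching server side is not part of this hunk; the sketch below is illustrative only, assuming the same request/reply encoding, and the serve_embeddings/encode_fn names are made up here (the real servers live in the backend packages such as leann_backend_hnsw.hnsw_embedding_server).

    # Minimal REP-side sketch of the protocol assumed by compute_embeddings_via_server:
    # receive a msgpack list of strings, reply with a msgpack list of embedding rows.
    import msgpack
    import numpy as np
    import zmq

    def serve_embeddings(port: int, encode_fn) -> None:
        """encode_fn: callable mapping list[str] -> np.ndarray of shape (n, dim)."""
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        socket.bind(f"tcp://*:{port}")
        while True:
            chunks = msgpack.unpackb(socket.recv())
            embeddings = np.asarray(encode_fn(chunks), dtype=np.float32)
            socket.send(msgpack.packb(embeddings.tolist()))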
@@ -347,8 +150,6 @@ class LeannBuilder:
         self.dimensions = dimensions
         self.embedding_mode = embedding_mode
         self.backend_kwargs = backend_kwargs
-        if 'mlx' in self.embedding_model:
-            self.embedding_mode = "mlx"
         self.chunks: List[Dict[str, Any]] = []
 
     def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
@@ -380,7 +181,10 @@ class LeannBuilder:
         with open(passages_file, "w", encoding="utf-8") as f:
             try:
                 from tqdm import tqdm
-                chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
+
+                chunk_iterator = tqdm(
+                    self.chunks, desc="Writing passages", unit="chunk"
+                )
             except ImportError:
                 chunk_iterator = self.chunks
 
@@ -401,7 +205,11 @@ class LeannBuilder:
             pickle.dump(offset_map, f)
         texts_to_embed = [c["text"] for c in self.chunks]
         embeddings = compute_embeddings(
-            texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
+            texts_to_embed,
+            self.embedding_model,
+            self.embedding_mode,
+            use_server=False,
+            port=5557,
         )
         string_ids = [chunk["id"] for chunk in self.chunks]
         current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
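After this change the two call paths of compute_embeddings split cleanly: index building passes use_server=False and computes directly through leann.embedding_compute, while search/query code keeps the default use_server=True and talks to the embedding server. A rough usage sketch, assuming the module import path and using an example model name:

    from leann.api import compute_embeddings  # assumed import path

    chunks = ["first passage", "second passage"]

    # Build path: direct computation, no server round-trip
    doc_vecs = compute_embeddings(
        chunks, "sentence-transformers/all-MiniLM-L6-v2",
        mode="sentence-transformers", use_server=False,
    )

    # Search path: route through the embedding server over ZMQ
    query_vec = compute_embeddings(
        ["what changed in this commit?"], "sentence-transformers/all-MiniLM-L6-v2",
        use_server=True, port=5557,
    )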
packages/leann-core/src/leann/embedding_compute.py (new file)
@@ -0,0 +1,272 @@
"""
Unified embedding computation module
Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""

import numpy as np
import torch
from typing import List
import logging

logger = logging.getLogger(__name__)


def compute_embeddings(
    texts: List[str], model_name: str, mode: str = "sentence-transformers"
) -> np.ndarray:
    """
    Unified embedding computation entry point

    Args:
        texts: List of texts to compute embeddings for
        model_name: Model name
        mode: Computation mode ('sentence-transformers', 'openai', 'mlx')

    Returns:
        Normalized embeddings array, shape: (len(texts), embedding_dim)
    """
    if mode == "sentence-transformers":
        return compute_embeddings_sentence_transformers(texts, model_name)
    elif mode == "openai":
        return compute_embeddings_openai(texts, model_name)
    elif mode == "mlx":
        return compute_embeddings_mlx(texts, model_name)
    else:
        raise ValueError(f"Unsupported embedding mode: {mode}")


def compute_embeddings_sentence_transformers(
    texts: List[str],
    model_name: str,
    use_fp16: bool = True,
    device: str = "auto",
    batch_size: int = 32,
) -> np.ndarray:
    """
    Compute embeddings using SentenceTransformer
    Preserves all optimization parameters to ensure consistency with original embedding_server

    Args:
        texts: List of texts to compute embeddings for
        model_name: SentenceTransformer model name
        use_fp16: Whether to use FP16 precision
        device: Device selection ('auto', 'cuda', 'mps', 'cpu')
        batch_size: Batch size for processing

    Returns:
        Normalized embeddings array, shape: (len(texts), embedding_dim)
    """
    print(
        f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
    )

    from sentence_transformers import SentenceTransformer

    # Auto-detect device
    if device == "auto":
        if torch.cuda.is_available():
            device = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

    print(f"INFO: Using device: {device}")

    # Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
    model_kwargs = {
        "torch_dtype": torch.float16 if use_fp16 else torch.float32,
        "low_cpu_mem_usage": True,
        "_fast_init": True,  # Skip weight initialization checks for faster loading
    }

    tokenizer_kwargs = {
        "use_fast": True,  # Use fast tokenizer for better runtime performance
    }

    # Load SentenceTransformer (try local first, then network)
    print(f"INFO: Loading SentenceTransformer model: {model_name}")

    try:
        # Try local loading (avoid network delays)
        model_kwargs["local_files_only"] = True
        tokenizer_kwargs["local_files_only"] = True

        model = SentenceTransformer(
            model_name,
            device=device,
            model_kwargs=model_kwargs,
            tokenizer_kwargs=tokenizer_kwargs,
            local_files_only=True,
        )
        print("✅ Model loaded successfully! (local + optimized)")
    except Exception as e:
        print(f"Local loading failed ({e}), trying network download...")
        # Fallback to network loading
        model_kwargs["local_files_only"] = False
        tokenizer_kwargs["local_files_only"] = False

        model = SentenceTransformer(
            model_name,
            device=device,
            model_kwargs=model_kwargs,
            tokenizer_kwargs=tokenizer_kwargs,
            local_files_only=False,
        )
        print("✅ Model loaded successfully! (network + optimized)")

    # Apply additional optimizations (if supported)
    if use_fp16 and device in ["cuda", "mps"]:
        try:
            model = model.half()
            model = torch.compile(model)
            print(f"✅ Using FP16 precision and compile optimization: {model_name}")
        except Exception as e:
            print(
                f"FP16 or compile optimization failed, continuing with default settings: {e}"
            )

    # Compute embeddings (using SentenceTransformer's optimized implementation)
    print("INFO: Starting embedding computation...")

    embeddings = model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=False,  # Don't show progress bar in server environment
        convert_to_numpy=True,
        normalize_embeddings=False,  # Keep consistent with original API behavior
        device=device,
    )

    print(
        f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
    )

    # Validate results
    if np.isnan(embeddings).any() or np.isinf(embeddings).any():
        raise RuntimeError(
            f"Detected NaN or Inf values in embeddings, model: {model_name}"
        )

    return embeddings


def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
    """Compute embeddings using OpenAI API"""
    try:
        import openai
        import os
    except ImportError as e:
        raise ImportError(f"OpenAI package not installed: {e}")

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY environment variable not set")

    client = openai.OpenAI(api_key=api_key)

    print(
        f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
    )

    # OpenAI has limits on batch size and input length
    max_batch_size = 100  # Conservative batch size
    all_embeddings = []

    try:
        from tqdm import tqdm

        total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
        batch_range = range(0, len(texts), max_batch_size)
        batch_iterator = tqdm(
            batch_range, desc="Computing embeddings", unit="batch", total=total_batches
        )
    except ImportError:
        # Fallback when tqdm is not available
        batch_iterator = range(0, len(texts), max_batch_size)

    for i in batch_iterator:
        batch_texts = texts[i : i + max_batch_size]

        try:
            response = client.embeddings.create(model=model_name, input=batch_texts)
            batch_embeddings = [embedding.embedding for embedding in response.data]
            all_embeddings.extend(batch_embeddings)
        except Exception as e:
            print(f"ERROR: Batch {i} failed: {e}")
            raise

    embeddings = np.array(all_embeddings, dtype=np.float32)
    print(
        f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
    )
    return embeddings


def compute_embeddings_mlx(
    chunks: List[str], model_name: str, batch_size: int = 16
) -> np.ndarray:
    """Computes embeddings using an MLX model."""
    try:
        import mlx.core as mx
        from mlx_lm.utils import load
        from tqdm import tqdm
    except ImportError as e:
        raise RuntimeError(
            "MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
        ) from e

    print(
        f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
    )

    # Load model and tokenizer
    model, tokenizer = load(model_name)

    # Process chunks in batches with progress bar
    all_embeddings = []

    try:
        from tqdm import tqdm

        batch_iterator = tqdm(
            range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
        )
    except ImportError:
        batch_iterator = range(0, len(chunks), batch_size)

    for i in batch_iterator:
        batch_chunks = chunks[i : i + batch_size]

        # Tokenize all chunks in the batch
        batch_token_ids = []
        for chunk in batch_chunks:
            token_ids = tokenizer.encode(chunk)  # type: ignore
            batch_token_ids.append(token_ids)

        # Pad sequences to the same length for batch processing
        max_length = max(len(ids) for ids in batch_token_ids)
        padded_token_ids = []
        for token_ids in batch_token_ids:
            # Pad with tokenizer.pad_token_id or 0
            padded = token_ids + [0] * (max_length - len(token_ids))
            padded_token_ids.append(padded)

        # Convert to MLX array with batch dimension
        input_ids = mx.array(padded_token_ids)

        # Get embeddings for the batch
        embeddings = model(input_ids)

        # Mean pooling for each sequence in the batch
        pooled = embeddings.mean(axis=1)  # Shape: (batch_size, hidden_size)

        # Convert batch embeddings to numpy
        for j in range(len(batch_chunks)):
            pooled_list = pooled[j].tolist()  # Convert to list
            pooled_numpy = np.array(pooled_list, dtype=np.float32)
            all_embeddings.append(pooled_numpy)

    # Stack numpy arrays
    return np.stack(all_embeddings)
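Given the file path above, the module should be importable as leann.embedding_compute. A minimal usage sketch of the unified entry point; the model name is just an example:

    from leann.embedding_compute import compute_embeddings

    texts = ["hello world", "unified embedding computation"]
    vecs = compute_embeddings(
        texts, "sentence-transformers/all-MiniLM-L6-v2", mode="sentence-transformers"
    )
    print(vecs.shape)  # (2, embedding_dim)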
@@ -4,10 +4,8 @@ import atexit
 import socket
 import subprocess
 import sys
-import zmq
-import msgpack
 from pathlib import Path
-from typing import Optional, Dict
+from typing import Optional
 import select
 import psutil
 
@@ -19,7 +17,7 @@ def _check_port(port: int) -> bool:
 
 
 def _check_process_matches_config(
-    port: int, expected_model: str, expected_passages_file: str = None
+    port: int, expected_model: str, expected_passages_file: str
 ) -> bool:
     """
     Check if the process using the port matches our expected model and passages file.
@@ -34,7 +32,9 @@ def _check_process_matches_config(
             if not cmdline:
                 continue
 
-            return _check_cmdline_matches_config(cmdline, port, expected_model, expected_passages_file)
+            return _check_cmdline_matches_config(
+                cmdline, port, expected_model, expected_passages_file
+            )
 
     print(f"DEBUG: No process found listening on port {port}")
     return False
@@ -57,18 +57,21 @@ def _is_process_listening_on_port(proc, port: int) -> bool:
 
 
 def _check_cmdline_matches_config(
-    cmdline: list, port: int, expected_model: str, expected_passages_file: str = None
+    cmdline: list, port: int, expected_model: str, expected_passages_file: str
 ) -> bool:
     """Check if command line matches our expected configuration."""
     cmdline_str = " ".join(cmdline)
     print(f"DEBUG: Found process on port {port}: {cmdline_str}")
 
     # Check if it's our embedding server
-    is_embedding_server = any(server_type in cmdline_str for server_type in [
-        "embedding_server",
-        "leann_backend_diskann.embedding_server",
-        "leann_backend_hnsw.hnsw_embedding_server"
-    ])
+    is_embedding_server = any(
+        server_type in cmdline_str
+        for server_type in [
+            "embedding_server",
+            "leann_backend_diskann.embedding_server",
+            "leann_backend_hnsw.hnsw_embedding_server",
+        ]
+    )
 
     if not is_embedding_server:
         print(f"DEBUG: Process on port {port} is not our embedding server")
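This hunk only reflows the any(...) check, but it documents the compatibility test: the manager looks up the process listening on the target port and decides from its command line whether it is one of the known embedding-server modules. Below is a standalone sketch of that port-to-cmdline lookup with psutil; the helper name is ours, not part of the codebase, and the real checks also compare the model and passages-file arguments.

    import psutil

    def listening_cmdline(port: int):
        """Return the argv of the process listening on `port`, or None."""
        for conn in psutil.net_connections(kind="inet"):
            if (
                conn.status == psutil.CONN_LISTEN
                and conn.laddr
                and conn.laddr.port == port
                and conn.pid
            ):
                try:
                    return psutil.Process(conn.pid).cmdline()
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    return None
        return None

    # A compatible server would show a command line containing, for example,
    # "-m leann_backend_hnsw.hnsw_embedding_server --zmq-port 5557 --model-name <model>".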
@@ -81,7 +84,9 @@ def _check_cmdline_matches_config(
     passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
 
     result = model_matches and passages_matches
-    print(f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}")
+    print(
+        f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
+    )
     return result
 
 
@@ -98,11 +103,8 @@ def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
     return actual_model == expected_model
 
 
-def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str = None) -> bool:
+def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
     """Check if the command line contains the expected passages file."""
-    if not expected_passages_file:
-        return True  # No passages file expected
-
     if "--passages-file" not in cmdline:
         return False  # Expected but not found
 
@@ -117,7 +119,7 @@ def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str = None
 
 
 def _find_compatible_port_or_next_available(
-    start_port: int, model_name: str, passages_file: str = None, max_attempts: int = 100
+    start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
 ) -> tuple[int, bool]:
     """
     Find a port that either has a compatible server or is available.
@@ -177,9 +179,13 @@ class EmbeddingServerManager:
             tuple[bool, int]: (success, actual_port_used)
         """
         passages_file = kwargs.get("passages_file")
+        assert isinstance(passages_file, str), "passages_file must be a string"
 
         # Check if we have a compatible running server
         if self._has_compatible_running_server(model_name, passages_file):
+            assert self.server_port is not None, (
+                "a compatible running server should set server_port"
+            )
             return True, self.server_port
 
         # Find available port (compatible or free)
@@ -203,20 +209,29 @@ class EmbeddingServerManager:
         # Start new server
         return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
 
-    def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
+    def _has_compatible_running_server(
+        self, model_name: str, passages_file: str
+    ) -> bool:
         """Check if we have a compatible running server."""
-        if not (self.server_process and self.server_process.poll() is None and self.server_port):
+        if not (
+            self.server_process
+            and self.server_process.poll() is None
+            and self.server_port
+        ):
             return False
 
         if _check_process_matches_config(self.server_port, model_name, passages_file):
-            print(f"✅ Existing server process (PID {self.server_process.pid}) is compatible")
+            print(
+                f"✅ Existing server process (PID {self.server_process.pid}) is compatible"
+            )
             return True
 
-        print("⚠️ Existing server process is incompatible. Stopping it...")
-        self.stop_server()
+        print("⚠️ Existing server process is incompatible. Should start a new server.")
         return False
 
-    def _start_new_server(self, port: int, model_name: str, embedding_mode: str, **kwargs) -> tuple[bool, int]:
+    def _start_new_server(
+        self, port: int, model_name: str, embedding_mode: str, **kwargs
+    ) -> tuple[bool, int]:
         """Start a new embedding server on the given port."""
         print(f"INFO: Starting embedding server on port {port}...")
 
@@ -229,20 +244,24 @@ class EmbeddingServerManager:
             print(f"❌ ERROR: Failed to start embedding server: {e}")
             return False, port
 
-    def _build_server_command(self, port: int, model_name: str, embedding_mode: str, **kwargs) -> list:
+    def _build_server_command(
+        self, port: int, model_name: str, embedding_mode: str, **kwargs
+    ) -> list:
         """Build the command to start the embedding server."""
         command = [
-            sys.executable, "-m", self.backend_module_name,
-            "--zmq-port", str(port),
-            "--model-name", model_name,
+            sys.executable,
+            "-m",
+            self.backend_module_name,
+            "--zmq-port",
+            str(port),
+            "--model-name",
+            model_name,
         ]
 
         if kwargs.get("passages_file"):
             command.extend(["--passages-file", str(kwargs["passages_file"])])
         if embedding_mode != "sentence-transformers":
             command.extend(["--embedding-mode", embedding_mode])
-        if kwargs.get("enable_warmup") is False:
-            command.extend(["--disable-warmup"])
 
         return command
 
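With the reformatted list above, the launch command keeps the same flags except that --disable-warmup is no longer appended. For the HNSW backend it comes out roughly as follows; the model name and passages path are illustrative, not values from this diff:

    import sys

    command = [
        sys.executable, "-m", "leann_backend_hnsw.hnsw_embedding_server",
        "--zmq-port", "5557",
        "--model-name", "sentence-transformers/all-MiniLM-L6-v2",
        "--passages-file", "/path/to/index.passages.jsonl",
    ]
    print(" ".join(command))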
@@ -252,9 +271,14 @@ class EmbeddingServerManager:
         print(f"INFO: Command: {' '.join(command)}")
 
         self.server_process = subprocess.Popen(
-            command, cwd=project_root,
-            stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
-            text=True, encoding="utf-8", bufsize=1, universal_newlines=True,
+            command,
+            cwd=project_root,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            text=True,
+            encoding="utf-8",
+            bufsize=1,
+            universal_newlines=True,
         )
         self.server_port = port
         print(f"INFO: Server process started with PID: {self.server_process.pid}")
@@ -323,14 +347,18 @@ class EmbeddingServerManager:
             self.server_process = None
             return
 
-        print(f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}...")
+        print(
+            f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
+        )
         self.server_process.terminate()
 
         try:
             self.server_process.wait(timeout=5)
             print(f"INFO: Server process {self.server_process.pid} terminated.")
         except subprocess.TimeoutExpired:
-            print(f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it.")
+            print(
+                f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it."
+            )
             self.server_process.kill()
 
         self.server_process = None
@@ -43,7 +43,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                 "WARNING: embedding_model not found in meta.json. Recompute will fail."
             )
-
 
         self.embedding_server_manager = EmbeddingServerManager(
             backend_module_name=backend_module_name
         )
@@ -57,10 +56,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         with open(meta_path, "r", encoding="utf-8") as f:
             return json.load(f)
-
 
     def _ensure_server_running(
         self, passages_source_file: str, port: int, **kwargs
-    ) -> None:
+    ) -> int:
         """
         Ensures the embedding server is running if recompute is needed.
         This is a helper for subclasses.
@@ -81,11 +79,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             enable_warmup=kwargs.get("enable_warmup", False),
         )
         if not server_started:
-            raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
+            raise RuntimeError(
+                f"Failed to start embedding server on port {actual_port}"
+            )
 
-        # Update the port information for future use
-        if hasattr(self, '_actual_server_port'):
-            self._actual_server_port = actual_port
+        return actual_port
 
     def compute_query_embedding(
         self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
@@ -105,8 +103,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         if use_server_if_available:
             try:
                 # Ensure we have a server with passages_file for compatibility
-                passages_source_file = self.index_dir / f"{self.index_path.name}.meta.json"
-                self._ensure_server_running(str(passages_source_file), zmq_port)
+                passages_source_file = (
+                    self.index_dir / f"{self.index_path.name}.meta.json"
+                )
+                zmq_port = self._ensure_server_running(
+                    str(passages_source_file), zmq_port
+                )
 
                 return self._compute_embedding_via_server([query], zmq_port)[
                     0:1