Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG

This commit is contained in:
yichuan520030910320
2025-07-17 22:29:39 -07:00
5 changed files with 348 additions and 110 deletions

View File

@@ -20,7 +20,8 @@ from .chat import get_llm
def compute_embeddings(
chunks: List[str],
model_name: str,
mode: str = "sentence-transformers"
mode: str = "sentence-transformers",
use_server: bool = True
) -> np.ndarray:
"""
Computes embeddings using different backends.
@@ -32,6 +33,7 @@ def compute_embeddings(
- "sentence-transformers": Use sentence-transformers library (default)
- "mlx": Use MLX backend for Apple Silicon
- "openai": Use OpenAI embedding API
use_server: Whether to use embedding server (True for search, False for build)
Returns:
numpy array of embeddings
@@ -45,13 +47,79 @@ def compute_embeddings(
elif mode == "openai":
return compute_embeddings_openai(chunks, model_name)
elif mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(chunks, model_name)
return compute_embeddings_sentence_transformers(chunks, model_name, use_server=use_server)
else:
raise ValueError(f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai")
def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings using sentence-transformers library."""
def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str, use_server: bool = True) -> np.ndarray:
"""Computes embeddings using sentence-transformers.
Args:
chunks: List of text chunks to embed
model_name: Name of the sentence transformer model
use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
"""
if not use_server:
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)...")
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
)
# Use embedding server for sentence-transformers too
# This avoids loading the model twice (once in API, once in server)
try:
# Import ZMQ client functionality and server manager
import zmq
import msgpack
import numpy as np
from .embedding_server_manager import EmbeddingServerManager
# Ensure embedding server is running
port = 5557
server_manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw.hnsw_embedding_server")
server_started = server_manager.start_server(
port=port,
model_name=model_name,
embedding_mode="sentence-transformers",
enable_warmup=False,
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}")
# Connect to embedding server
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{port}")
# Send chunks to server for embedding computation
request = chunks
socket.send(msgpack.packb(request))
# Receive embeddings from server
response = socket.recv()
embeddings_list = msgpack.unpackb(response)
# Convert back to numpy array
embeddings = np.array(embeddings_list, dtype=np.float32)
socket.close()
context.term()
return embeddings
except Exception as e:
# Fallback to direct sentence-transformers if server connection fails
print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}")
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
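The client above assumes a msgpack-over-ZMQ peer: a REP socket that unpacks a list of strings and replies with nested lists of floats. A minimal sketch of such a loop follows (an assumption for illustration; the shipped hnsw_embedding_server may differ):

# Hypothetical server-side counterpart, matching the request/response shape
# used by the client above (list[str] in, list[list[float]] out).
import zmq
import msgpack
from sentence_transformers import SentenceTransformer

def serve_embeddings(model_name: str, port: int = 5557) -> None:
    model = SentenceTransformer(model_name)  # load once, reuse for every request
    context = zmq.Context()
    socket = context.socket(zmq.REP)  # REP pairs with the client's REQ
    socket.bind(f"tcp://*:{port}")
    while True:
        chunks = msgpack.unpackb(socket.recv())          # list[str]
        embeddings = model.encode(chunks)                # np.ndarray (n, dim)
        socket.send(msgpack.packb(embeddings.tolist()))  # nested lists back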
def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray:
"""Direct sentence-transformers computation (fallback)."""
try:
from sentence_transformers import SentenceTransformer
except ImportError as e:
@@ -64,7 +132,7 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str)
model = model.half()
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'..."
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
)
# use accelerator GPU (CUDA) or Mac GPU (MPS) when available
@@ -240,7 +308,7 @@ class LeannBuilder:
raise ValueError("No chunks added.")
if self.dimensions is None:
self.dimensions = len(
compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode)[0]
compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode, use_server=False)[0]
)
path = Path(index_path)
index_dir = path.parent
@@ -267,7 +335,7 @@ class LeannBuilder:
pickle.dump(offset_map, f)
texts_to_embed = [c["text"] for c in self.chunks]
embeddings = compute_embeddings(
texts_to_embed, self.embedding_model, self.embedding_mode
texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
)
string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
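Taken together, a build pass never touches the embedding server; a usage sketch under assumed API names (the constructor arguments and method names below are not confirmed by this diff):

# Hypothetical end-to-end build, for illustration only:
builder = LeannBuilder(
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # hypothetical
    embedding_mode="sentence-transformers",
)
builder.add_text("LEANN keeps passages next to the index.")  # assumed helper
builder.build_index("./indexes/demo.leann")  # embeds with use_server=False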

View File

@@ -200,7 +200,27 @@ class EmbeddingServerManager:
# Check model compatibility
model_matches = _check_server_model(self.server_port, model_name)
if not model_matches:
if model_matches:
print(
f"✅ Existing server already using correct model: {model_name}"
)
# Still check meta path if provided
passages_file = kwargs.get("passages_file")
if passages_file and str(passages_file).endswith(
".meta.json"
):
meta_matches = _check_server_meta_path(
self.server_port, str(passages_file)
)
if not meta_matches:
print("⚠️ Updating meta path to: {passages_file}")
_update_server_meta_path(
self.server_port, str(passages_file)
)
return True
else:
print(
f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
)
@@ -230,11 +250,6 @@ class EmbeddingServerManager:
)
return True
else:
print(
f"✅ Existing server already using correct model: {model_name}"
)
return True
else:
# Server process exists but port not responding - restart
print("⚠️ Server process exists but not responding. Restarting...")
@@ -254,7 +269,11 @@ class EmbeddingServerManager:
# Check model compatibility first
model_matches = _check_server_model(port, model_name)
if not model_matches:
if model_matches:
print(
f"✅ Existing server on port {port} is using correct model: {model_name}"
)
else:
print(
f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
)
@@ -263,10 +282,6 @@ class EmbeddingServerManager:
f"❌ Failed to update server model to {model_name}. Consider using a different port."
)
print(f"✅ Successfully updated server model to: {model_name}")
else:
print(
f"✅ Existing server on port {port} is using correct model: {model_name}"
)
# Check meta path compatibility if provided
if passages_file and str(passages_file).endswith(".meta.json"):