Merge branch 'main' of github.com:yichuan520030910320/LEANN-RAG
@@ -20,7 +20,8 @@ from .chat import get_llm
 def compute_embeddings(
     chunks: List[str],
     model_name: str,
-    mode: str = "sentence-transformers"
+    mode: str = "sentence-transformers",
+    use_server: bool = True
 ) -> np.ndarray:
     """
     Computes embeddings using different backends.
@@ -32,6 +33,7 @@ def compute_embeddings(
             - "sentence-transformers": Use sentence-transformers library (default)
             - "mlx": Use MLX backend for Apple Silicon
             - "openai": Use OpenAI embedding API
+        use_server: Whether to use embedding server (True for search, False for build)
 
     Returns:
         numpy array of embeddings
@@ -45,13 +47,79 @@ def compute_embeddings(
     elif mode == "openai":
         return compute_embeddings_openai(chunks, model_name)
     elif mode == "sentence-transformers":
-        return compute_embeddings_sentence_transformers(chunks, model_name)
+        return compute_embeddings_sentence_transformers(chunks, model_name, use_server=use_server)
     else:
         raise ValueError(f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai")
 
 
-def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str) -> np.ndarray:
-    """Computes embeddings using sentence-transformers library."""
+def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str, use_server: bool = True) -> np.ndarray:
+    """Computes embeddings using sentence-transformers.
+
+    Args:
+        chunks: List of text chunks to embed
+        model_name: Name of the sentence transformer model
+        use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
+    """
+    if not use_server:
+        print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)...")
+        return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
+
+    print(
+        f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
+    )
+
+    # Use embedding server for sentence-transformers too
+    # This avoids loading the model twice (once in API, once in server)
+    try:
+        # Import ZMQ client functionality and server manager
+        import zmq
+        import msgpack
+        import numpy as np
+        from .embedding_server_manager import EmbeddingServerManager
+
+        # Ensure embedding server is running
+        port = 5557
+        server_manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw.hnsw_embedding_server")
+
+        server_started = server_manager.start_server(
+            port=port,
+            model_name=model_name,
+            embedding_mode="sentence-transformers",
+            enable_warmup=False,
+        )
+
+        if not server_started:
+            raise RuntimeError(f"Failed to start embedding server on port {port}")
+
+        # Connect to embedding server
+        context = zmq.Context()
+        socket = context.socket(zmq.REQ)
+        socket.connect(f"tcp://localhost:{port}")
+
+        # Send chunks to server for embedding computation
+        request = chunks
+        socket.send(msgpack.packb(request))
+
+        # Receive embeddings from server
+        response = socket.recv()
+        embeddings_list = msgpack.unpackb(response)
+
+        # Convert back to numpy array
+        embeddings = np.array(embeddings_list, dtype=np.float32)
+
+        socket.close()
+        context.term()
+
+        return embeddings
+
+    except Exception as e:
+        # Fallback to direct sentence-transformers if server connection fails
+        print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}")
+        return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
+
+
+def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray:
+    """Direct sentence-transformers computation (fallback)."""
     try:
         from sentence_transformers import SentenceTransformer
     except ImportError as e:
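The request/response protocol used above is deliberately thin: the client msgpack-encodes the raw list of chunk strings, sends it over a ZMQ REQ socket, and receives a msgpack-encoded list of embedding rows back. A minimal standalone client sketch of that protocol (the port and address are the hard-coded values from the hunk above; the helper name is hypothetical):

    import zmq
    import msgpack
    import numpy as np

    def embed_via_server(chunks, port=5557):
        """Round-trip a list of strings to the embedding server (sketch)."""
        context = zmq.Context()
        socket = context.socket(zmq.REQ)
        socket.connect(f"tcp://localhost:{port}")
        try:
            socket.send(msgpack.packb(chunks))      # request: plain list[str]
            reply = msgpack.unpackb(socket.recv())  # reply: list of float rows
            return np.array(reply, dtype=np.float32)
        finally:
            socket.close()
            context.term()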
@@ -64,7 +132,7 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str)
 
         model = model.half()
     print(
-        f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'..."
+        f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
     )
     # use accelerator GPU or Mac GPU
 
@@ -240,7 +308,7 @@ class LeannBuilder:
             raise ValueError("No chunks added.")
         if self.dimensions is None:
             self.dimensions = len(
-                compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode)[0]
+                compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode, use_server=False)[0]
             )
         path = Path(index_path)
         index_dir = path.parent
@@ -267,7 +335,7 @@ class LeannBuilder:
             pickle.dump(offset_map, f)
         texts_to_embed = [c["text"] for c in self.chunks]
         embeddings = compute_embeddings(
-            texts_to_embed, self.embedding_model, self.embedding_mode
+            texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
         )
         string_ids = [chunk["id"] for chunk in self.chunks]
         current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
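Both LeannBuilder call sites pass use_server=False because an index build embeds each chunk exactly once, so there is nothing to gain from routing through a resident server; search-time callers keep the default use_server=True and share one loaded model. A hypothetical call pattern under the new signature (the model name is illustrative, not taken from the diff):

    chunks = ["first passage", "second passage"]

    # Build path: one-shot, in-process embedding, no server spin-up.
    embeddings = compute_embeddings(chunks, "all-MiniLM-L6-v2", use_server=False)

    # Search path: default use_server=True reuses the long-lived server on port 5557.
    query_embedding = compute_embeddings(["example query"], "all-MiniLM-L6-v2")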
@@ -200,7 +200,27 @@ class EmbeddingServerManager:
 
                 # Check model compatibility
                 model_matches = _check_server_model(self.server_port, model_name)
-                if not model_matches:
+                if model_matches:
+                    print(
+                        f"✅ Existing server already using correct model: {model_name}"
+                    )
+
+                    # Still check meta path if provided
+                    passages_file = kwargs.get("passages_file")
+                    if passages_file and str(passages_file).endswith(
+                        ".meta.json"
+                    ):
+                        meta_matches = _check_server_meta_path(
+                            self.server_port, str(passages_file)
+                        )
+                        if not meta_matches:
+                            print(f"⚠️ Updating meta path to: {passages_file}")
+                            _update_server_meta_path(
+                                self.server_port, str(passages_file)
+                            )
+
+                    return True
+                else:
                     print(
                         f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
                     )
@@ -230,11 +250,6 @@
                     )
 
                     return True
-                else:
-                    print(
-                        f"✅ Existing server already using correct model: {model_name}"
-                    )
-                    return True
             else:
                 # Server process exists but port not responding - restart
                 print("⚠️ Server process exists but not responding. Restarting...")
@@ -254,7 +269,11 @@
 
         # Check model compatibility first
        model_matches = _check_server_model(port, model_name)
-        if not model_matches:
+        if model_matches:
+            print(
+                f"✅ Existing server on port {port} is using correct model: {model_name}"
+            )
+        else:
             print(
                 f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
             )
@@ -263,10 +282,6 @@
                     f"❌ Failed to update server model to {model_name}. Consider using a different port."
                 )
             print(f"✅ Successfully updated server model to: {model_name}")
-        else:
-            print(
-                f"✅ Existing server on port {port} is using correct model: {model_name}"
-            )
 
         # Check meta path compatibility if provided
         if passages_file and str(passages_file).endswith(".meta.json"):
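Taken together, the rewritten checks in both functions implement one reuse policy: keep an already-running server when its model matches, update the model or meta path in place when they drift, and restart only when the process stops responding. A condensed sketch of that policy (_update_server_model is an assumed name for the update call the log messages imply; the other helpers appear in the diff):

    def reuse_or_update(port, model_name, passages_file=None):
        """Paraphrased reuse policy from the hunks above (not verbatim)."""
        if not _check_server_model(port, model_name):
            # Model drifted: try to update in place before giving up.
            if not _update_server_model(port, model_name):  # assumed helper name
                return False  # caller restarts the server or picks another port
        if passages_file and str(passages_file).endswith(".meta.json"):
            if not _check_server_meta_path(port, str(passages_file)):
                _update_server_meta_path(port, str(passages_file))
        return True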