change WeChat app split logic & merge
@@ -18,15 +18,15 @@ from .chat import get_llm
 def compute_embeddings(
     chunks: List[str],
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
-    use_mlx: bool = False # Backward compatibility: if True, override mode to 'mlx'
+    use_mlx: bool = False,  # Backward compatibility: if True, override mode to 'mlx'
 ) -> np.ndarray:
     """
     Computes embeddings using different backends.
 
     Args:
         chunks: List of text chunks to embed
         model_name: Name of the embedding model
@@ -35,7 +35,7 @@ def compute_embeddings(
             - "mlx": Use MLX backend for Apple Silicon
             - "openai": Use OpenAI embedding API
         use_server: Whether to use embedding server (True for search, False for build)
 
     Returns:
         numpy array of embeddings
     """
@@ -46,33 +46,41 @@ def compute_embeddings(
     # Auto-detect mode based on model name if not explicitly set
     if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
         mode = "openai"
 
     if mode == "mlx":
         return compute_embeddings_mlx(chunks, model_name, batch_size=16)
     elif mode == "openai":
         return compute_embeddings_openai(chunks, model_name)
     elif mode == "sentence-transformers":
-        return compute_embeddings_sentence_transformers(chunks, model_name, use_server=use_server)
+        return compute_embeddings_sentence_transformers(
+            chunks, model_name, use_server=use_server
+        )
     else:
-        raise ValueError(f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai")
+        raise ValueError(
+            f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
+        )
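A minimal usage sketch of the dispatch above; the model names are illustrative assumptions, not part of this commit:

# sentence-transformers path, computed directly (no embedding server):
embs = compute_embeddings(["hello", "world"], "all-MiniLM-L6-v2", use_server=False)
# the "text-embedding-" prefix auto-switches the mode to "openai":
embs = compute_embeddings(["hello", "world"], "text-embedding-3-small")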
-def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str, use_server: bool = True) -> np.ndarray:
+def compute_embeddings_sentence_transformers(
+    chunks: List[str], model_name: str, use_server: bool = True
+) -> np.ndarray:
     """Computes embeddings using sentence-transformers.
 
     Args:
         chunks: List of text chunks to embed
         model_name: Name of the sentence transformer model
         use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
     """
     if not use_server:
-        print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)...")
+        print(
+            f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
+        )
         return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
 
+    print(
+        f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
+    )
 
     # Use embedding server for sentence-transformers too
     # This avoids loading the model twice (once in API, once in server)
     try:
@@ -81,49 +89,55 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str,
         import msgpack
         import numpy as np
         from .embedding_server_manager import EmbeddingServerManager
 
         # Ensure embedding server is running
         port = 5557
-        server_manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw.hnsw_embedding_server")
+        server_manager = EmbeddingServerManager(
+            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
+        )
         server_started = server_manager.start_server(
             port=port,
             model_name=model_name,
             embedding_mode="sentence-transformers",
             enable_warmup=False,
         )
 
         if not server_started:
             raise RuntimeError(f"Failed to start embedding server on port {port}")
 
         # Connect to embedding server
         context = zmq.Context()
         socket = context.socket(zmq.REQ)
         socket.connect(f"tcp://localhost:{port}")
 
         # Send chunks to server for embedding computation
         request = chunks
         socket.send(msgpack.packb(request))
 
         # Receive embeddings from server
         response = socket.recv()
         embeddings_list = msgpack.unpackb(response)
 
         # Convert back to numpy array
         embeddings = np.array(embeddings_list, dtype=np.float32)
 
         socket.close()
         context.term()
 
         return embeddings
 
     except Exception as e:
         # Fallback to direct sentence-transformers if server connection fails
-        print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}")
+        print(
+            f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
+        )
         return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
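For reference, the server exchange above reduces to a plain ZMQ REQ/REP round trip: a msgpack-encoded list of strings goes out, a msgpack-encoded list of float vectors comes back. A standalone client sketch, assuming a server is already listening on port 5557:

import msgpack
import numpy as np
import zmq

context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect("tcp://localhost:5557")
socket.send(msgpack.packb(["first chunk", "second chunk"]))  # request: list of strings
embeddings = np.array(msgpack.unpackb(socket.recv()), dtype=np.float32)  # reply: list of vectors
socket.close()
context.term()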
-def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray:
+def _compute_embeddings_sentence_transformers_direct(
+    chunks: List[str], model_name: str
+) -> np.ndarray:
     """Direct sentence-transformers computation (fallback)."""
     try:
         from sentence_transformers import SentenceTransformer
@@ -164,16 +178,18 @@ def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
         raise RuntimeError(
             "openai not available. Install with: uv pip install openai"
         ) from e
 
     # Get API key from environment
     api_key = os.getenv("OPENAI_API_KEY")
     if not api_key:
         raise RuntimeError("OPENAI_API_KEY environment variable not set")
 
     client = openai.OpenAI(api_key=api_key)
-    print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...")
+    print(
+        f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
+    )
 
     # OpenAI has a limit on batch size and input length
     max_batch_size = 100  # Conservative batch size
     all_embeddings = []
@@ -191,18 +207,17 @@ def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
         batch_chunks = chunks[i:i + max_batch_size]
 
         try:
-            response = client.embeddings.create(
-                model=model_name,
-                input=batch_chunks
-            )
+            response = client.embeddings.create(model=model_name, input=batch_chunks)
             batch_embeddings = [embedding.embedding for embedding in response.data]
             all_embeddings.extend(batch_embeddings)
         except Exception as e:
             print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
             raise
 
     embeddings = np.array(all_embeddings, dtype=np.float32)
-    print(f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}")
+    print(
+        f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
+    )
     return embeddings
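A usage sketch for the OpenAI path; the input text is a placeholder and the key must already be exported:

import os
assert os.getenv("OPENAI_API_KEY"), "set OPENAI_API_KEY first"  # checked at call time above
embs = compute_embeddings_openai(["some text"], "text-embedding-3-small")
print(embs.shape)  # (1, 1536) for text-embedding-3-small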
@@ -345,7 +360,12 @@ class LeannBuilder:
             raise ValueError("No chunks added.")
         if self.dimensions is None:
             self.dimensions = len(
-                compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode, use_server=False)[0]
+                compute_embeddings(
+                    ["dummy"],
+                    self.embedding_model,
+                    self.embedding_mode,
+                    use_server=False,
+                )[0]
             )
         path = Path(index_path)
         index_dir = path.parent
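The probe above infers the index dimensions by embedding one throwaway string; use_server=False keeps the build path from spawning an embedding server just for this. An equivalent standalone sketch (model name assumed):

dims = len(
    compute_embeddings(["dummy"], "all-MiniLM-L6-v2", "sentence-transformers", use_server=False)[0]
)
print(dims)  # 384 for all-MiniLM-L6-v2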
@@ -414,6 +434,129 @@ class LeannBuilder:
         with open(leann_meta_path, "w", encoding="utf-8") as f:
             json.dump(meta_data, f, indent=2)
 
+    def build_index_from_embeddings(self, index_path: str, embeddings_file: str):
+        """
+        Build an index from pre-computed embeddings stored in a pickle file.
+
+        Args:
+            index_path: Path where the index will be saved
+            embeddings_file: Path to pickle file containing (ids, embeddings) tuple
+        """
+        # Load pre-computed embeddings
+        with open(embeddings_file, "rb") as f:
+            data = pickle.load(f)
+
+        if not isinstance(data, tuple) or len(data) != 2:
+            raise ValueError(
+                f"Invalid embeddings file format. Expected tuple with 2 elements, got {type(data)}"
+            )
+
+        ids, embeddings = data
+
+        if not isinstance(embeddings, np.ndarray):
+            raise ValueError(
+                f"Expected embeddings to be numpy array, got {type(embeddings)}"
+            )
+
+        if len(ids) != embeddings.shape[0]:
+            raise ValueError(
+                f"Mismatch between number of IDs ({len(ids)}) and embeddings ({embeddings.shape[0]})"
+            )
+
+        # Validate/set dimensions
+        embedding_dim = embeddings.shape[1]
+        if self.dimensions is None:
+            self.dimensions = embedding_dim
+        elif self.dimensions != embedding_dim:
+            raise ValueError(
+                f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
+            )
+
+        print(
+            f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
+        )
+
+        # Ensure we have text data for each embedding
+        if len(self.chunks) != len(ids):
+            # If no text chunks provided, create placeholder text entries
+            if not self.chunks:
+                print("No text chunks provided, creating placeholder entries...")
+                for id_val in ids:
+                    self.add_text(
+                        f"Document {id_val}",
+                        metadata={"id": str(id_val), "from_embeddings": True},
+                    )
+            else:
+                raise ValueError(
+                    f"Number of text chunks ({len(self.chunks)}) doesn't match number of embeddings ({len(ids)})"
+                )
+
+        # Build file structure
+        path = Path(index_path)
+        index_dir = path.parent
+        index_name = path.name
+        index_dir.mkdir(parents=True, exist_ok=True)
+        passages_file = index_dir / f"{index_name}.passages.jsonl"
+        offset_file = index_dir / f"{index_name}.passages.idx"
+
+        # Write passages and create offset map
+        offset_map = {}
+        with open(passages_file, "w", encoding="utf-8") as f:
+            for chunk in self.chunks:
+                offset = f.tell()
+                json.dump(
+                    {
+                        "id": chunk["id"],
+                        "text": chunk["text"],
+                        "metadata": chunk["metadata"],
+                    },
+                    f,
+                    ensure_ascii=False,
+                )
+                f.write("\n")
+                offset_map[chunk["id"]] = offset
+
+        with open(offset_file, "wb") as f:
+            pickle.dump(offset_map, f)
+
+        # Build the vector index using precomputed embeddings
+        string_ids = [str(id_val) for id_val in ids]
+        current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
+        builder_instance = self.backend_factory.builder(**current_backend_kwargs)
+        builder_instance.build(embeddings, string_ids, index_path)
+
+        # Create metadata file
+        leann_meta_path = index_dir / f"{index_name}.meta.json"
+        meta_data = {
+            "version": "1.0",
+            "backend_name": self.backend_name,
+            "embedding_model": self.embedding_model,
+            "dimensions": self.dimensions,
+            "backend_kwargs": self.backend_kwargs,
+            "embedding_mode": self.embedding_mode,
+            "passage_sources": [
+                {
+                    "type": "jsonl",
+                    "path": str(passages_file),
+                    "index_path": str(offset_file),
+                }
+            ],
+            "built_from_precomputed_embeddings": True,
+            "embeddings_source": str(embeddings_file),
+        }
+
+        # Add storage status flags for HNSW backend
+        if self.backend_name == "hnsw":
+            is_compact = self.backend_kwargs.get("is_compact", True)
+            is_recompute = self.backend_kwargs.get("is_recompute", True)
+            meta_data["is_compact"] = is_compact
+            meta_data["is_pruned"] = is_compact and is_recompute
+
+        with open(leann_meta_path, "w", encoding="utf-8") as f:
+            json.dump(meta_data, f, indent=2)
+
+        print(f"Index built successfully from precomputed embeddings: {index_path}")
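A usage sketch for the new method; it assumes LeannBuilder accepts backend_name and embedding_model keyword arguments, which this diff does not show:

import pickle
import numpy as np

# The file must hold the (ids, embeddings) tuple validated above.
ids = ["doc-0", "doc-1"]
vectors = np.random.rand(2, 384).astype(np.float32)
with open("precomputed.pkl", "wb") as f:
    pickle.dump((ids, vectors), f)

builder = LeannBuilder(backend_name="hnsw", embedding_model="all-MiniLM-L6-v2")  # assumed signature
builder.build_index_from_embeddings("indexes/demo.index", "precomputed.pkl")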
 
 
 class LeannSearcher:
     def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
@@ -425,7 +568,9 @@ class LeannSearcher:
         backend_name = self.meta_data["backend_name"]
         self.embedding_model = self.meta_data["embedding_model"]
         # Support both old and new format
-        self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
+        self.embedding_mode = self.meta_data.get(
+            "embedding_mode", "sentence-transformers"
+        )
         # Backward compatibility with use_mlx
         if self.meta_data.get("use_mlx", False):
             self.embedding_mode = "mlx"
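The two-step fallback above means an old meta.json that only carries use_mlx still resolves correctly; a sketch of the resolution order:

meta_data = {"use_mlx": True}  # old-format metadata, no "embedding_mode" key
mode = meta_data.get("embedding_mode", "sentence-transformers")
if meta_data.get("use_mlx", False):
    mode = "mlx"
assert mode == "mlx"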
@@ -457,6 +602,7 @@ class LeannSearcher:
         # Use backend's compute_query_embedding method
         # This will automatically use embedding server if available and needed
         import time
 
         start_time = time.time()
 
         query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
@@ -556,7 +702,7 @@ class LeannChat:
             "Please provide the best answer you can based on this context and your knowledge."
         )
 
-        ans=self.llm.ask(prompt, **llm_kwargs)
+        ans = self.llm.ask(prompt, **llm_kwargs)
         return ans
 
     def start_interactive(self):