Compare commits
8 Commits
feature/co
...
perf-build
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1d084f678c | ||
|
|
54155e8b10 | ||
|
|
f47f76d6d7 | ||
|
|
1dc3923b53 | ||
|
|
7e226a51c9 | ||
|
|
f4998bb316 | ||
|
|
7522de1d41 | ||
|
|
15f8bd1cc9 |
@@ -48,7 +48,7 @@ git submodule update --init --recursive
|
||||
|
||||
**macOS:**
|
||||
```bash
|
||||
brew install llvm libomp boost protobuf
|
||||
brew install llvm libomp boost protobuf zeromq
|
||||
export CC=$(brew --prefix llvm)/bin/clang
|
||||
export CXX=$(brew --prefix llvm)/bin/clang++
|
||||
|
||||
@@ -61,7 +61,7 @@ uv sync --extra diskann
|
||||
|
||||
**Linux (Ubuntu/Debian):**
|
||||
```bash
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
uv sync
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
|
||||
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
|
||||
|
||||
cmake_minimum_required(VERSION 3.20)
|
||||
project(leann_backend_diskann_wrapper)
|
||||
|
||||
# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
|
||||
# DiskANN 会自己处理所有事情,包括编译 Python 绑定
|
||||
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
|
||||
# DiskANN will handle everything itself, including compiling Python bindings
|
||||
add_subdirectory(src/third_party/DiskANN)
|
||||
|
||||
@@ -70,10 +70,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
data_filename = f"{index_prefix}_data.bin"
|
||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||
|
||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||
label_map_file = index_dir / "leann.labels.map"
|
||||
with open(label_map_file, "wb") as f:
|
||||
pickle.dump(label_map, f)
|
||||
|
||||
build_kwargs = {**self.build_params, **kwargs}
|
||||
metric_enum = _get_diskann_metrics().get(
|
||||
@@ -211,10 +207,7 @@ class DiskannSearcher(BaseSearcher):
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[
|
||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
||||
for int_label in batch_labels
|
||||
]
|
||||
[str(int_label) for int_label in batch_labels]
|
||||
for batch_labels in labels
|
||||
]
|
||||
|
||||
|
||||
@@ -76,24 +76,11 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
||||
finally:
|
||||
sys.path.pop(0)
|
||||
|
||||
# Load label map
|
||||
passages_dir = Path(meta_file).parent
|
||||
label_map_file = passages_dir / "leann.labels.map"
|
||||
|
||||
if label_map_file.exists():
|
||||
import pickle
|
||||
with open(label_map_file, 'rb') as f:
|
||||
label_map = pickle.load(f)
|
||||
print(f"Loaded label map with {len(label_map)} entries")
|
||||
else:
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
|
||||
print(f"Initialized lazy passage loading for {len(label_map)} passages")
|
||||
print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")
|
||||
|
||||
class LazyPassageLoader(SimplePassageLoader):
|
||||
def __init__(self, passage_manager, label_map):
|
||||
def __init__(self, passage_manager):
|
||||
self.passage_manager = passage_manager
|
||||
self.label_map = label_map
|
||||
# Initialize parent with empty data
|
||||
super().__init__({})
|
||||
|
||||
@@ -101,25 +88,22 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
||||
"""Get passage by ID with lazy loading"""
|
||||
try:
|
||||
int_id = int(passage_id)
|
||||
if int_id in self.label_map:
|
||||
string_id = self.label_map[int_id]
|
||||
passage_data = self.passage_manager.get_passage(string_id)
|
||||
if passage_data and passage_data.get("text"):
|
||||
return {"text": passage_data["text"]}
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
|
||||
string_id = str(int_id)
|
||||
passage_data = self.passage_manager.get_passage(string_id)
|
||||
if passage_data and passage_data.get("text"):
|
||||
return {"text": passage_data["text"]}
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
|
||||
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.label_map)
|
||||
return len(self.passage_manager.global_offset_map)
|
||||
|
||||
def keys(self):
|
||||
return self.label_map.keys()
|
||||
return self.passage_manager.global_offset_map.keys()
|
||||
|
||||
loader = LazyPassageLoader(passage_manager, label_map)
|
||||
loader = LazyPassageLoader(passage_manager)
|
||||
loader._meta_path = meta_file
|
||||
return loader
|
||||
|
||||
@@ -135,35 +119,15 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
|
||||
if not passages_file.endswith('.jsonl'):
|
||||
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
|
||||
|
||||
# Load label map (int -> string_id)
|
||||
passages_dir = Path(passages_file).parent
|
||||
label_map_file = passages_dir / "leann.labels.map"
|
||||
|
||||
label_map = {}
|
||||
if label_map_file.exists():
|
||||
with open(label_map_file, 'rb') as f:
|
||||
label_map = pickle.load(f)
|
||||
print(f"Loaded label map with {len(label_map)} entries")
|
||||
else:
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
|
||||
# Load passages by string ID
|
||||
string_id_passages = {}
|
||||
# Load passages directly by their sequential IDs
|
||||
passages_data = {}
|
||||
with open(passages_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
string_id_passages[passage['id']] = passage['text']
|
||||
passages_data[passage['id']] = passage['text']
|
||||
|
||||
# Create int ID -> text mapping using label map
|
||||
passages_data = {}
|
||||
for int_id, string_id in label_map.items():
|
||||
if string_id in string_id_passages:
|
||||
passages_data[str(int_id)] = string_id_passages[string_id]
|
||||
else:
|
||||
print(f"WARNING: String ID {string_id} from label map not found in passages")
|
||||
|
||||
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
|
||||
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}")
|
||||
return SimplePassageLoader(passages_data)
|
||||
|
||||
def create_embedding_server_thread(
|
||||
|
||||
@@ -8,11 +8,11 @@ version = "0.1.0"
|
||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
||||
|
||||
[tool.scikit-build]
|
||||
# 关键:简化的 CMake 路径
|
||||
# Key: simplified CMake path
|
||||
cmake.source-dir = "third_party/DiskANN"
|
||||
# 关键:Python 包在根目录,路径完全匹配
|
||||
# Key: Python package in root directory, paths match exactly
|
||||
wheel.packages = ["leann_backend_diskann"]
|
||||
# 使用默认的 redirect 模式
|
||||
# Use default redirect mode
|
||||
editable.mode = "redirect"
|
||||
cmake.build-type = "Release"
|
||||
build.verbose = true
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# 最终简化版
|
||||
cmake_minimum_required(VERSION 3.24)
|
||||
project(leann_backend_hnsw_wrapper)
|
||||
set(CMAKE_C_COMPILER_WORKS 1)
|
||||
set(CMAKE_CXX_COMPILER_WORKS 1)
|
||||
|
||||
# Set OpenMP path for macOS
|
||||
if(APPLE)
|
||||
@@ -11,15 +12,9 @@ if(APPLE)
|
||||
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
||||
endif()
|
||||
|
||||
# Build ZeroMQ from source
|
||||
set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
|
||||
set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
|
||||
set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
|
||||
set(WITH_DOCS OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_STATIC ON CACHE BOOL "" FORCE)
|
||||
add_subdirectory(third_party/libzmq)
|
||||
# Use system ZeroMQ instead of building from source
|
||||
find_package(PkgConfig REQUIRED)
|
||||
pkg_check_modules(ZMQ REQUIRED libzmq)
|
||||
|
||||
# Add cppzmq headers
|
||||
include_directories(third_party/cppzmq)
|
||||
@@ -29,6 +24,7 @@ set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
||||
add_compile_definitions(MSGPACK_NO_BOOST)
|
||||
include_directories(third_party/msgpack-c/include)
|
||||
|
||||
# Faiss configuration - streamlined build
|
||||
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
|
||||
@@ -36,4 +32,24 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||
|
||||
# Disable additional SIMD versions to speed up compilation
|
||||
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# Additional optimization options from INSTALL.md
|
||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
||||
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) # Static library is faster to build
|
||||
|
||||
# Avoid building demos and benchmarks
|
||||
set(BUILD_DEMOS OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_BENCHS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# NEW: Tell Faiss to only build the generic version
|
||||
set(FAISS_BUILD_GENERIC ON CACHE BOOL "" FORCE)
|
||||
set(FAISS_BUILD_AVX2 OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# IMPORTANT: Disable building AVX versions to speed up compilation
|
||||
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
add_subdirectory(third_party/faiss)
|
||||
@@ -59,10 +59,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
if data.dtype != np.float32:
|
||||
data = data.astype(np.float32)
|
||||
|
||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||
label_map_file = index_dir / "leann.labels.map"
|
||||
with open(label_map_file, "wb") as f:
|
||||
pickle.dump(label_map, f)
|
||||
|
||||
metric_enum = get_metric_map().get(self.distance_metric.lower())
|
||||
if metric_enum is None:
|
||||
@@ -142,13 +138,6 @@ class HNSWSearcher(BaseSearcher):
|
||||
|
||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||
|
||||
# Load label mapping
|
||||
label_map_file = self.index_dir / "leann.labels.map"
|
||||
if not label_map_file.exists():
|
||||
raise FileNotFoundError(f"Label map file not found at {label_map_file}")
|
||||
|
||||
with open(label_map_file, "rb") as f:
|
||||
self.label_map = pickle.load(f)
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -239,10 +228,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[
|
||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
||||
for int_label in batch_labels
|
||||
]
|
||||
[str(int_label) for int_label in batch_labels]
|
||||
for batch_labels in labels
|
||||
]
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -15,4 +15,8 @@ wheel.packages = ["leann_backend_hnsw"]
|
||||
editable.mode = "redirect"
|
||||
cmake.build-type = "Release"
|
||||
build.verbose = true
|
||||
build.tool-args = ["-j8"]
|
||||
build.tool-args = ["-j8"]
|
||||
|
||||
# CMake definitions to optimize compilation
|
||||
[tool.scikit-build.cmake.define]
|
||||
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 2547df4377...ff22e2c86b
@@ -15,5 +15,8 @@ dependencies = [
|
||||
"tqdm>=4.60.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
leann = "leann.cli:main"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -9,9 +9,6 @@ import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
from dataclasses import dataclass, field
|
||||
import uuid
|
||||
import torch
|
||||
|
||||
from .registry import BACKEND_REGISTRY
|
||||
from .interface import LeannBackendFactoryInterface
|
||||
from .chat import get_llm
|
||||
@@ -22,7 +19,7 @@ def compute_embeddings(
|
||||
model_name: str,
|
||||
mode: str = "sentence-transformers",
|
||||
use_server: bool = True,
|
||||
use_mlx: bool = False # Backward compatibility: if True, override mode to 'mlx',
|
||||
port: int = 5557,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes embeddings using different backends.
|
||||
@@ -39,251 +36,60 @@ def compute_embeddings(
|
||||
Returns:
|
||||
numpy array of embeddings
|
||||
"""
|
||||
# Override mode for backward compatibility
|
||||
if use_mlx:
|
||||
mode = "mlx"
|
||||
|
||||
# Auto-detect mode based on model name if not explicitly set
|
||||
if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
|
||||
mode = "openai"
|
||||
|
||||
if mode == "mlx":
|
||||
return compute_embeddings_mlx(chunks, model_name, batch_size=16)
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(chunks, model_name)
|
||||
elif mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(
|
||||
chunks, model_name, use_server=use_server
|
||||
)
|
||||
if use_server:
|
||||
# Use embedding server (for search/query)
|
||||
return compute_embeddings_via_server(chunks, model_name, port=port)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
|
||||
# Use direct computation (for build_index)
|
||||
from .embedding_compute import (
|
||||
compute_embeddings as compute_embeddings_direct,
|
||||
)
|
||||
|
||||
return compute_embeddings_direct(
|
||||
chunks,
|
||||
model_name,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
|
||||
def compute_embeddings_sentence_transformers(
|
||||
chunks: List[str], model_name: str, use_server: bool = True
|
||||
def compute_embeddings_via_server(
|
||||
chunks: List[str], model_name: str, port: int
|
||||
) -> np.ndarray:
|
||||
"""Computes embeddings using sentence-transformers.
|
||||
|
||||
Args:
|
||||
chunks: List of text chunks to embed
|
||||
model_name: Name of the sentence transformer model
|
||||
use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
|
||||
"""
|
||||
if not use_server:
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
|
||||
)
|
||||
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||
)
|
||||
import zmq
|
||||
import msgpack
|
||||
import numpy as np
|
||||
|
||||
# Use embedding server for sentence-transformers too
|
||||
# This avoids loading the model twice (once in API, once in server)
|
||||
try:
|
||||
# Import ZMQ client functionality and server manager
|
||||
import zmq
|
||||
import msgpack
|
||||
import numpy as np
|
||||
from .embedding_server_manager import EmbeddingServerManager
|
||||
# Connect to embedding server
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Ensure embedding server is running
|
||||
port = 5557
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
|
||||
server_started = server_manager.start_server(
|
||||
port=port,
|
||||
model_name=model_name,
|
||||
embedding_mode="sentence-transformers",
|
||||
enable_warmup=False,
|
||||
)
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
|
||||
# Connect to embedding server
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
# Fallback to direct sentence-transformers if server connection fails
|
||||
print(
|
||||
f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
|
||||
)
|
||||
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
||||
|
||||
|
||||
def _compute_embeddings_sentence_transformers_direct(
|
||||
chunks: List[str], model_name: str
|
||||
) -> np.ndarray:
|
||||
"""Direct sentence-transformers computation (fallback)."""
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"sentence-transformers not available. Install with: uv pip install sentence-transformers"
|
||||
) from e
|
||||
|
||||
# Load model using sentence-transformers
|
||||
model = SentenceTransformer(model_name)
|
||||
|
||||
model = model.half()
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
|
||||
)
|
||||
# use acclerater GPU or MAC GPU
|
||||
|
||||
if torch.cuda.is_available():
|
||||
model = model.to("cuda")
|
||||
elif torch.backends.mps.is_available():
|
||||
model = model.to("mps")
|
||||
|
||||
# Generate embeddings
|
||||
# give use an warning if OOM here means we need to turn down the batch size
|
||||
embeddings = model.encode(
|
||||
chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=16
|
||||
)
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
"""Computes embeddings using OpenAI API."""
|
||||
try:
|
||||
import openai
|
||||
import os
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"openai not available. Install with: uv pip install openai"
|
||||
) from e
|
||||
|
||||
# Get API key from environment
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
|
||||
)
|
||||
|
||||
# OpenAI has a limit on batch size and input length
|
||||
max_batch_size = 100 # Conservative batch size
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
total_batches = (len(chunks) + max_batch_size - 1) // max_batch_size
|
||||
batch_range = range(0, len(chunks), max_batch_size)
|
||||
batch_iterator = tqdm(batch_range, desc="Computing embeddings", unit="batch", total=total_batches)
|
||||
except ImportError:
|
||||
# Fallback without progress bar
|
||||
batch_iterator = range(0, len(chunks), max_batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_chunks = chunks[i:i + max_batch_size]
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(model=model_name, input=batch_chunks)
|
||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
|
||||
raise
|
||||
|
||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||
print(
|
||||
f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
|
||||
)
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
|
||||
"""Computes embeddings using an MLX model."""
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load
|
||||
from tqdm import tqdm
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
|
||||
) from e
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
||||
)
|
||||
|
||||
# Load model and tokenizer
|
||||
model, tokenizer = load(model_name)
|
||||
|
||||
# Process chunks in batches with progress bar
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
batch_iterator = tqdm(range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch")
|
||||
except ImportError:
|
||||
batch_iterator = range(0, len(chunks), batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_chunks = chunks[i:i + batch_size]
|
||||
|
||||
# Tokenize all chunks in the batch
|
||||
batch_token_ids = []
|
||||
for chunk in batch_chunks:
|
||||
token_ids = tokenizer.encode(chunk) # type: ignore
|
||||
batch_token_ids.append(token_ids)
|
||||
|
||||
# Pad sequences to the same length for batch processing
|
||||
max_length = max(len(ids) for ids in batch_token_ids)
|
||||
padded_token_ids = []
|
||||
for token_ids in batch_token_ids:
|
||||
# Pad with tokenizer.pad_token_id or 0
|
||||
padded = token_ids + [0] * (max_length - len(token_ids))
|
||||
padded_token_ids.append(padded)
|
||||
|
||||
# Convert to MLX array with batch dimension
|
||||
input_ids = mx.array(padded_token_ids)
|
||||
|
||||
# Get embeddings for the batch
|
||||
embeddings = model(input_ids)
|
||||
|
||||
# Mean pooling for each sequence in the batch
|
||||
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
|
||||
|
||||
# Convert batch embeddings to numpy
|
||||
for j in range(len(batch_chunks)):
|
||||
pooled_list = pooled[j].tolist() # Convert to list
|
||||
pooled_numpy = np.array(pooled_list, dtype=np.float32)
|
||||
all_embeddings.append(pooled_numpy)
|
||||
|
||||
# Stack numpy arrays
|
||||
return np.stack(all_embeddings)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
id: str
|
||||
@@ -344,14 +150,12 @@ class LeannBuilder:
|
||||
self.dimensions = dimensions
|
||||
self.embedding_mode = embedding_mode
|
||||
self.backend_kwargs = backend_kwargs
|
||||
if 'mlx' in self.embedding_model:
|
||||
self.embedding_mode = "mlx"
|
||||
self.chunks: List[Dict[str, Any]] = []
|
||||
|
||||
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
passage_id = metadata.get("id", str(uuid.uuid4()))
|
||||
passage_id = metadata.get("id", str(len(self.chunks)))
|
||||
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
|
||||
self.chunks.append(chunk_data)
|
||||
|
||||
@@ -377,10 +181,13 @@ class LeannBuilder:
|
||||
with open(passages_file, "w", encoding="utf-8") as f:
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
|
||||
|
||||
chunk_iterator = tqdm(
|
||||
self.chunks, desc="Writing passages", unit="chunk"
|
||||
)
|
||||
except ImportError:
|
||||
chunk_iterator = self.chunks
|
||||
|
||||
|
||||
for chunk in chunk_iterator:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
@@ -398,7 +205,11 @@ class LeannBuilder:
|
||||
pickle.dump(offset_map, f)
|
||||
texts_to_embed = [c["text"] for c in self.chunks]
|
||||
embeddings = compute_embeddings(
|
||||
texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
|
||||
texts_to_embed,
|
||||
self.embedding_model,
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
port=5557,
|
||||
)
|
||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||
|
||||
287
packages/leann-core/src/leann/cli.py
Normal file
287
packages/leann-core/src/leann/cli.py
Normal file
@@ -0,0 +1,287 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
from .api import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
|
||||
class LeannCLI:
|
||||
def __init__(self):
|
||||
self.indexes_dir = Path.home() / ".leann" / "indexes"
|
||||
self.indexes_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
def get_index_path(self, index_name: str) -> str:
|
||||
index_dir = self.indexes_dir / index_name
|
||||
return str(index_dir / "documents.leann")
|
||||
|
||||
def index_exists(self, index_name: str) -> bool:
|
||||
index_dir = self.indexes_dir / index_name
|
||||
meta_file = index_dir / "documents.leann.meta.json"
|
||||
return meta_file.exists()
|
||||
|
||||
def create_parser(self) -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="leann",
|
||||
description="LEANN - Local Enhanced AI Navigation",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
leann build my-docs --docs ./documents # Build index named my-docs
|
||||
leann search my-docs "query" # Search in my-docs index
|
||||
leann ask my-docs "question" # Ask my-docs index
|
||||
leann list # List all stored indexes
|
||||
"""
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser("build", help="Build document index")
|
||||
build_parser.add_argument("index_name", help="Index name")
|
||||
build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
|
||||
build_parser.add_argument("--backend", type=str, default="hnsw", choices=["hnsw", "diskann"])
|
||||
build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
|
||||
build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
|
||||
build_parser.add_argument("--graph-degree", type=int, default=32)
|
||||
build_parser.add_argument("--complexity", type=int, default=64)
|
||||
build_parser.add_argument("--num-threads", type=int, default=1)
|
||||
build_parser.add_argument("--compact", action="store_true", default=True)
|
||||
build_parser.add_argument("--recompute", action="store_true", default=True)
|
||||
|
||||
# Search command
|
||||
search_parser = subparsers.add_parser("search", help="Search documents")
|
||||
search_parser.add_argument("index_name", help="Index name")
|
||||
search_parser.add_argument("query", help="Search query")
|
||||
search_parser.add_argument("--top-k", type=int, default=5)
|
||||
search_parser.add_argument("--complexity", type=int, default=64)
|
||||
search_parser.add_argument("--beam-width", type=int, default=1)
|
||||
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
search_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
search_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||
ask_parser.add_argument("index_name", help="Index name")
|
||||
ask_parser.add_argument("--llm", type=str, default="ollama", choices=["simulated", "ollama", "hf", "openai"])
|
||||
ask_parser.add_argument("--model", type=str, default="qwen3:8b")
|
||||
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
||||
ask_parser.add_argument("--interactive", "-i", action="store_true")
|
||||
ask_parser.add_argument("--top-k", type=int, default=20)
|
||||
ask_parser.add_argument("--complexity", type=int, default=32)
|
||||
ask_parser.add_argument("--beam-width", type=int, default=1)
|
||||
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
ask_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
ask_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
|
||||
|
||||
# List command
|
||||
list_parser = subparsers.add_parser("list", help="List all indexes")
|
||||
|
||||
return parser
|
||||
|
||||
def list_indexes(self):
|
||||
print("Stored LEANN indexes:")
|
||||
|
||||
if not self.indexes_dir.exists():
|
||||
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||
return
|
||||
|
||||
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
|
||||
|
||||
if not index_dirs:
|
||||
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||
return
|
||||
|
||||
print(f"Found {len(index_dirs)} indexes:")
|
||||
for i, index_dir in enumerate(index_dirs, 1):
|
||||
index_name = index_dir.name
|
||||
status = "✓" if self.index_exists(index_name) else "✗"
|
||||
|
||||
print(f" {i}. {index_name} [{status}]")
|
||||
if self.index_exists(index_name):
|
||||
meta_file = index_dir / "documents.leann.meta.json"
|
||||
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (1024 * 1024)
|
||||
print(f" Size: {size_mb:.1f} MB")
|
||||
|
||||
if index_dirs:
|
||||
example_name = index_dirs[0].name
|
||||
print(f"\nUsage:")
|
||||
print(f" leann search {example_name} \"your query\"")
|
||||
print(f" leann ask {example_name} --interactive")
|
||||
|
||||
def load_documents(self, docs_dir: str):
|
||||
print(f"Loading documents from {docs_dir}...")
|
||||
|
||||
documents = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md", ".docx"],
|
||||
).load_data(show_progress=True)
|
||||
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = self.node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
||||
return all_texts
|
||||
|
||||
async def build_index(self, args):
|
||||
docs_dir = args.docs
|
||||
index_name = args.index_name
|
||||
index_dir = self.indexes_dir / index_name
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
if index_dir.exists() and not args.force:
|
||||
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
|
||||
return
|
||||
|
||||
all_texts = self.load_documents(docs_dir)
|
||||
if not all_texts:
|
||||
print("No documents found")
|
||||
return
|
||||
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.complexity,
|
||||
is_compact=args.compact,
|
||||
is_recompute=args.recompute,
|
||||
num_threads=args.num_threads,
|
||||
)
|
||||
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"Index built at {index_path}")
|
||||
|
||||
async def search_documents(self, args):
|
||||
index_name = args.index_name
|
||||
query = args.query
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
if not self.index_exists(index_name):
|
||||
print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
|
||||
return
|
||||
|
||||
searcher = LeannSearcher(index_path=index_path)
|
||||
results = searcher.search(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy
|
||||
)
|
||||
|
||||
print(f"Search results for '{query}' (top {len(results)}):")
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"{i}. Score: {result.score:.3f}")
|
||||
print(f" {result.text[:200]}...")
|
||||
print()
|
||||
|
||||
async def ask_questions(self, args):
|
||||
index_name = args.index_name
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
if not self.index_exists(index_name):
|
||||
print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
|
||||
return
|
||||
|
||||
print(f"Starting chat with index '{index_name}'...")
|
||||
print(f"Using {args.model} ({args.llm})")
|
||||
|
||||
llm_config = {"type": args.llm, "model": args.model}
|
||||
if args.llm == "ollama":
|
||||
llm_config["host"] = args.host
|
||||
|
||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||
|
||||
if args.interactive:
|
||||
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||
print("=" * 40)
|
||||
|
||||
while True:
|
||||
user_input = input("\nYou: ").strip()
|
||||
if user_input.lower() in ['quit', 'exit', 'q']:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
response = chat.ask(
|
||||
user_input,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
else:
|
||||
query = input("Enter your question: ").strip()
|
||||
if query:
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
|
||||
async def run(self, args=None):
|
||||
parser = self.create_parser()
|
||||
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
if args.command == "list":
|
||||
self.list_indexes()
|
||||
elif args.command == "build":
|
||||
await self.build_index(args)
|
||||
elif args.command == "search":
|
||||
await self.search_documents(args)
|
||||
elif args.command == "ask":
|
||||
await self.ask_questions(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
def main():
|
||||
import dotenv
|
||||
dotenv.load_dotenv()
|
||||
|
||||
cli = LeannCLI()
|
||||
asyncio.run(cli.run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
272
packages/leann-core/src/leann/embedding_compute.py
Normal file
272
packages/leann-core/src/leann/embedding_compute.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Unified embedding computation module
|
||||
Consolidates all embedding computation logic using SentenceTransformer
|
||||
Preserves all optimization parameters to ensure performance
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import List
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compute_embeddings(
|
||||
texts: List[str], model_name: str, mode: str = "sentence-transformers"
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Unified embedding computation entry point
|
||||
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Model name
|
||||
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
"""
|
||||
if mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(texts, model_name)
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(texts, model_name)
|
||||
elif mode == "mlx":
|
||||
return compute_embeddings_mlx(texts, model_name)
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding mode: {mode}")
|
||||
|
||||
|
||||
def compute_embeddings_sentence_transformers(
|
||||
texts: List[str],
|
||||
model_name: str,
|
||||
use_fp16: bool = True,
|
||||
device: str = "auto",
|
||||
batch_size: int = 32,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embeddings using SentenceTransformer
|
||||
Preserves all optimization parameters to ensure consistency with original embedding_server
|
||||
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: SentenceTransformer model name
|
||||
use_fp16: Whether to use FP16 precision
|
||||
device: Device selection ('auto', 'cuda', 'mps', 'cpu')
|
||||
batch_size: Batch size for processing
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
"""
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
|
||||
)
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# Auto-detect device
|
||||
if device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
print(f"INFO: Using device: {device}")
|
||||
|
||||
# Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
|
||||
model_kwargs = {
|
||||
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
|
||||
"low_cpu_mem_usage": True,
|
||||
"_fast_init": True, # Skip weight initialization checks for faster loading
|
||||
}
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"use_fast": True, # Use fast tokenizer for better runtime performance
|
||||
}
|
||||
|
||||
# Load SentenceTransformer (try local first, then network)
|
||||
print(f"INFO: Loading SentenceTransformer model: {model_name}")
|
||||
|
||||
try:
|
||||
# Try local loading (avoid network delays)
|
||||
model_kwargs["local_files_only"] = True
|
||||
tokenizer_kwargs["local_files_only"] = True
|
||||
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
local_files_only=True,
|
||||
)
|
||||
print("✅ Model loaded successfully! (local + optimized)")
|
||||
except Exception as e:
|
||||
print(f"Local loading failed ({e}), trying network download...")
|
||||
# Fallback to network loading
|
||||
model_kwargs["local_files_only"] = False
|
||||
tokenizer_kwargs["local_files_only"] = False
|
||||
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
local_files_only=False,
|
||||
)
|
||||
print("✅ Model loaded successfully! (network + optimized)")
|
||||
|
||||
# Apply additional optimizations (if supported)
|
||||
if use_fp16 and device in ["cuda", "mps"]:
|
||||
try:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
print(f"✅ Using FP16 precision and compile optimization: {model_name}")
|
||||
except Exception as e:
|
||||
print(
|
||||
f"FP16 or compile optimization failed, continuing with default settings: {e}"
|
||||
)
|
||||
|
||||
# Compute embeddings (using SentenceTransformer's optimized implementation)
|
||||
print("INFO: Starting embedding computation...")
|
||||
|
||||
embeddings = model.encode(
|
||||
texts,
|
||||
batch_size=batch_size,
|
||||
show_progress_bar=False, # Don't show progress bar in server environment
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=False, # Keep consistent with original API behavior
|
||||
device=device,
|
||||
)
|
||||
|
||||
print(
|
||||
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||
)
|
||||
|
||||
# Validate results
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
raise RuntimeError(
|
||||
f"Detected NaN or Inf values in embeddings, model: {model_name}"
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
||||
"""Compute embeddings using OpenAI API"""
|
||||
try:
|
||||
import openai
|
||||
import os
|
||||
except ImportError as e:
|
||||
raise ImportError(f"OpenAI package not installed: {e}")
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||
)
|
||||
|
||||
# OpenAI has limits on batch size and input length
|
||||
max_batch_size = 100 # Conservative batch size
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
|
||||
batch_range = range(0, len(texts), max_batch_size)
|
||||
batch_iterator = tqdm(
|
||||
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
|
||||
)
|
||||
except ImportError:
|
||||
# Fallback when tqdm is not available
|
||||
batch_iterator = range(0, len(texts), max_batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_texts = texts[i : i + max_batch_size]
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(model=model_name, input=batch_texts)
|
||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Batch {i} failed: {e}")
|
||||
raise
|
||||
|
||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||
print(
|
||||
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||
)
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_mlx(
|
||||
chunks: List[str], model_name: str, batch_size: int = 16
|
||||
) -> np.ndarray:
|
||||
"""Computes embeddings using an MLX model."""
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load
|
||||
from tqdm import tqdm
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
|
||||
) from e
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
||||
)
|
||||
|
||||
# Load model and tokenizer
|
||||
model, tokenizer = load(model_name)
|
||||
|
||||
# Process chunks in batches with progress bar
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
batch_iterator = tqdm(
|
||||
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
|
||||
)
|
||||
except ImportError:
|
||||
batch_iterator = range(0, len(chunks), batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_chunks = chunks[i : i + batch_size]
|
||||
|
||||
# Tokenize all chunks in the batch
|
||||
batch_token_ids = []
|
||||
for chunk in batch_chunks:
|
||||
token_ids = tokenizer.encode(chunk) # type: ignore
|
||||
batch_token_ids.append(token_ids)
|
||||
|
||||
# Pad sequences to the same length for batch processing
|
||||
max_length = max(len(ids) for ids in batch_token_ids)
|
||||
padded_token_ids = []
|
||||
for token_ids in batch_token_ids:
|
||||
# Pad with tokenizer.pad_token_id or 0
|
||||
padded = token_ids + [0] * (max_length - len(token_ids))
|
||||
padded_token_ids.append(padded)
|
||||
|
||||
# Convert to MLX array with batch dimension
|
||||
input_ids = mx.array(padded_token_ids)
|
||||
|
||||
# Get embeddings for the batch
|
||||
embeddings = model(input_ids)
|
||||
|
||||
# Mean pooling for each sequence in the batch
|
||||
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
|
||||
|
||||
# Convert batch embeddings to numpy
|
||||
for j in range(len(batch_chunks)):
|
||||
pooled_list = pooled[j].tolist() # Convert to list
|
||||
pooled_numpy = np.array(pooled_list, dtype=np.float32)
|
||||
all_embeddings.append(pooled_numpy)
|
||||
|
||||
# Stack numpy arrays
|
||||
return np.stack(all_embeddings)
|
||||
@@ -4,11 +4,10 @@ import atexit
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import zmq
|
||||
import msgpack
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import select
|
||||
import psutil
|
||||
|
||||
|
||||
def _check_port(port: int) -> bool:
|
||||
@@ -17,151 +16,135 @@ def _check_port(port: int) -> bool:
|
||||
return s.connect_ex(("localhost", port)) == 0
|
||||
|
||||
|
||||
def _check_server_meta_path(port: int, expected_meta_path: str) -> bool:
|
||||
def _check_process_matches_config(
|
||||
port: int, expected_model: str, expected_passages_file: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the existing server on the port is using the correct meta file.
|
||||
Returns True if the server has the right meta path, False otherwise.
|
||||
Check if the process using the port matches our expected model and passages file.
|
||||
Returns True if matches, False otherwise.
|
||||
"""
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
for proc in psutil.process_iter(["pid", "cmdline"]):
|
||||
if not _is_process_listening_on_port(proc, port):
|
||||
continue
|
||||
|
||||
# Send a special control message to query the server's meta path
|
||||
control_request = ["__QUERY_META_PATH__"]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
cmdline = proc.info["cmdline"]
|
||||
if not cmdline:
|
||||
continue
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the response contains the meta path and if it matches
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
server_meta_path = response[0]
|
||||
# Normalize paths for comparison
|
||||
expected_path = Path(expected_meta_path).resolve()
|
||||
server_path = Path(server_meta_path).resolve() if server_meta_path else None
|
||||
return server_path == expected_path
|
||||
return _check_cmdline_matches_config(
|
||||
cmdline, port, expected_model, expected_passages_file
|
||||
)
|
||||
|
||||
print(f"DEBUG: No process found listening on port {port}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"WARNING: Could not query server meta path on port {port}: {e}")
|
||||
print(f"WARNING: Could not check process on port {port}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _update_server_meta_path(port: int, new_meta_path: str) -> bool:
|
||||
"""
|
||||
Send a control message to update the server's meta path.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
def _is_process_listening_on_port(proc, port: int) -> bool:
|
||||
"""Check if a process is listening on the given port."""
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send a control message to update the meta path
|
||||
control_request = ["__UPDATE_META_PATH__", new_meta_path]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the update was successful
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
return response[0] == "SUCCESS"
|
||||
|
||||
connections = proc.net_connections()
|
||||
for conn in connections:
|
||||
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Could not update server meta path on port {port}: {e}")
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
return False
|
||||
|
||||
|
||||
def _check_server_model(port: int, expected_model: str) -> bool:
|
||||
def _check_cmdline_matches_config(
|
||||
cmdline: list, port: int, expected_model: str, expected_passages_file: str
|
||||
) -> bool:
|
||||
"""Check if command line matches our expected configuration."""
|
||||
cmdline_str = " ".join(cmdline)
|
||||
print(f"DEBUG: Found process on port {port}: {cmdline_str}")
|
||||
|
||||
# Check if it's our embedding server
|
||||
is_embedding_server = any(
|
||||
server_type in cmdline_str
|
||||
for server_type in [
|
||||
"embedding_server",
|
||||
"leann_backend_diskann.embedding_server",
|
||||
"leann_backend_hnsw.hnsw_embedding_server",
|
||||
]
|
||||
)
|
||||
|
||||
if not is_embedding_server:
|
||||
print(f"DEBUG: Process on port {port} is not our embedding server")
|
||||
return False
|
||||
|
||||
# Check model name
|
||||
model_matches = _check_model_in_cmdline(cmdline, expected_model)
|
||||
|
||||
# Check passages file if provided
|
||||
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
|
||||
|
||||
result = model_matches and passages_matches
|
||||
print(
|
||||
f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
|
||||
"""Check if the command line contains the expected model."""
|
||||
if "--model-name" not in cmdline:
|
||||
return False
|
||||
|
||||
model_idx = cmdline.index("--model-name")
|
||||
if model_idx + 1 >= len(cmdline):
|
||||
return False
|
||||
|
||||
actual_model = cmdline[model_idx + 1]
|
||||
return actual_model == expected_model
|
||||
|
||||
|
||||
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
|
||||
"""Check if the command line contains the expected passages file."""
|
||||
if "--passages-file" not in cmdline:
|
||||
return False # Expected but not found
|
||||
|
||||
passages_idx = cmdline.index("--passages-file")
|
||||
if passages_idx + 1 >= len(cmdline):
|
||||
return False
|
||||
|
||||
actual_passages = cmdline[passages_idx + 1]
|
||||
expected_path = Path(expected_passages_file).resolve()
|
||||
actual_path = Path(actual_passages).resolve()
|
||||
return actual_path == expected_path
|
||||
|
||||
|
||||
def _find_compatible_port_or_next_available(
|
||||
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Check if the existing server on the port is using the correct embedding model.
|
||||
Returns True if the server has the right model, False otherwise.
|
||||
Find a port that either has a compatible server or is available.
|
||||
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
|
||||
"""
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
for port in range(start_port, start_port + max_attempts):
|
||||
if not _check_port(port):
|
||||
# Port is available
|
||||
return port, False
|
||||
|
||||
# Send a special control message to query the server's model
|
||||
control_request = ["__QUERY_MODEL__"]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
# Port is in use, check if it's compatible
|
||||
if _check_process_matches_config(port, model_name, passages_file):
|
||||
print(f"✅ Found compatible server on port {port}")
|
||||
return port, True
|
||||
else:
|
||||
print(f"⚠️ Port {port} has incompatible server, trying next port...")
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the response contains the model name and if it matches
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
server_model = response[0]
|
||||
return server_model == expected_model
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"WARNING: Could not query server model on port {port}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _update_server_model(port: int, new_model: str) -> bool:
|
||||
"""
|
||||
Send a control message to update the server's embedding model.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout for model loading
|
||||
socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout for sending
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send a control message to update the model
|
||||
control_request = ["__UPDATE_MODEL__", new_model]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the update was successful
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
return response[0] == "SUCCESS"
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Could not update server model on port {port}: {e}")
|
||||
return False
|
||||
raise RuntimeError(
|
||||
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingServerManager:
|
||||
"""
|
||||
A generic manager for handling the lifecycle of a backend-specific embedding server process.
|
||||
A simplified manager for embedding server processes that avoids complex update mechanisms.
|
||||
"""
|
||||
|
||||
def __init__(self, backend_module_name: str):
|
||||
@@ -175,210 +158,162 @@ class EmbeddingServerManager:
|
||||
self.backend_module_name = backend_module_name
|
||||
self.server_process: Optional[subprocess.Popen] = None
|
||||
self.server_port: Optional[int] = None
|
||||
atexit.register(self.stop_server)
|
||||
self._atexit_registered = False
|
||||
|
||||
def start_server(self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", **kwargs) -> bool:
|
||||
def start_server(
|
||||
self,
|
||||
port: int,
|
||||
model_name: str,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
**kwargs,
|
||||
) -> tuple[bool, int]:
|
||||
"""
|
||||
Starts the embedding server process.
|
||||
|
||||
Args:
|
||||
port (int): The ZMQ port for the server.
|
||||
port (int): The preferred ZMQ port for the server.
|
||||
model_name (str): The name of the embedding model to use.
|
||||
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup).
|
||||
**kwargs: Additional arguments for the server.
|
||||
|
||||
Returns:
|
||||
bool: True if the server is started successfully or already running, False otherwise.
|
||||
tuple[bool, int]: (success, actual_port_used)
|
||||
"""
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
# Even if we have a running process, check if model/meta path match
|
||||
if self.server_port is not None:
|
||||
port_in_use = _check_port(self.server_port)
|
||||
if port_in_use:
|
||||
print(
|
||||
f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})"
|
||||
)
|
||||
passages_file = kwargs.get("passages_file")
|
||||
assert isinstance(passages_file, str), "passages_file must be a string"
|
||||
|
||||
# Check model compatibility
|
||||
model_matches = _check_server_model(self.server_port, model_name)
|
||||
if model_matches:
|
||||
print(
|
||||
f"✅ Existing server already using correct model: {model_name}"
|
||||
)
|
||||
|
||||
# Still check meta path if provided
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if passages_file and str(passages_file).endswith(
|
||||
".meta.json"
|
||||
):
|
||||
meta_matches = _check_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
if not meta_matches:
|
||||
print("⚠️ Updating meta path to: {passages_file}")
|
||||
_update_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
|
||||
return True
|
||||
else:
|
||||
print(
|
||||
f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
|
||||
)
|
||||
if not _update_server_model(self.server_port, model_name):
|
||||
print(
|
||||
"❌ Failed to update existing server model. Restarting server..."
|
||||
)
|
||||
self.stop_server()
|
||||
# Continue to start new server below
|
||||
else:
|
||||
print(
|
||||
f"✅ Successfully updated existing server model to: {model_name}"
|
||||
)
|
||||
# Check if we have a compatible running server
|
||||
if self._has_compatible_running_server(model_name, passages_file):
|
||||
assert self.server_port is not None, (
|
||||
"a compatible running server should set server_port"
|
||||
)
|
||||
return True, self.server_port
|
||||
|
||||
# Also check meta path if provided
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if passages_file and str(passages_file).endswith(
|
||||
".meta.json"
|
||||
):
|
||||
meta_matches = _check_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
if not meta_matches:
|
||||
print("⚠️ Updating meta path to: {passages_file}")
|
||||
_update_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
# Find available port (compatible or free)
|
||||
try:
|
||||
actual_port, is_compatible = _find_compatible_port_or_next_available(
|
||||
port, model_name, passages_file
|
||||
)
|
||||
except RuntimeError as e:
|
||||
print(f"❌ {e}")
|
||||
return False, port
|
||||
|
||||
return True
|
||||
else:
|
||||
# Server process exists but port not responding - restart
|
||||
print("⚠️ Server process exists but not responding. Restarting...")
|
||||
self.stop_server()
|
||||
# Continue to start new server below
|
||||
else:
|
||||
# No port stored - restart
|
||||
print("⚠️ No port information stored. Restarting server...")
|
||||
self.stop_server()
|
||||
# Continue to start new server below
|
||||
if is_compatible:
|
||||
print(f"✅ Using existing compatible server on port {actual_port}")
|
||||
self.server_port = actual_port
|
||||
self.server_process = None # We don't own this process
|
||||
return True, actual_port
|
||||
|
||||
if _check_port(port):
|
||||
# Port is in use, check if it's using the correct meta file and model
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if actual_port != port:
|
||||
print(f"⚠️ Using port {actual_port} instead of {port}")
|
||||
|
||||
print(f"INFO: Port {port} is in use. Checking server compatibility...")
|
||||
# Start new server
|
||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||
|
||||
# Check model compatibility first
|
||||
model_matches = _check_server_model(port, model_name)
|
||||
if model_matches:
|
||||
print(
|
||||
f"✅ Existing server on port {port} is using correct model: {model_name}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
|
||||
)
|
||||
if not _update_server_model(port, model_name):
|
||||
raise RuntimeError(
|
||||
f"❌ Failed to update server model to {model_name}. Consider using a different port."
|
||||
)
|
||||
print(f"✅ Successfully updated server model to: {model_name}")
|
||||
def _has_compatible_running_server(
|
||||
self, model_name: str, passages_file: str
|
||||
) -> bool:
|
||||
"""Check if we have a compatible running server."""
|
||||
if not (
|
||||
self.server_process
|
||||
and self.server_process.poll() is None
|
||||
and self.server_port
|
||||
):
|
||||
return False
|
||||
|
||||
# Check meta path compatibility if provided
|
||||
if passages_file and str(passages_file).endswith(".meta.json"):
|
||||
meta_matches = _check_server_meta_path(port, str(passages_file))
|
||||
if not meta_matches:
|
||||
print(
|
||||
f"⚠️ Existing server on port {port} has different meta path. Attempting to update..."
|
||||
)
|
||||
if not _update_server_meta_path(port, str(passages_file)):
|
||||
raise RuntimeError(
|
||||
"❌ Failed to update server meta path. This may cause data synchronization issues."
|
||||
)
|
||||
print(
|
||||
f"✅ Successfully updated server meta path to: {passages_file}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"✅ Existing server on port {port} is using correct meta path: {passages_file}"
|
||||
)
|
||||
|
||||
print(f"✅ Server on port {port} is compatible and ready to use.")
|
||||
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
||||
print(
|
||||
f"✅ Existing server process (PID {self.server_process.pid}) is compatible"
|
||||
)
|
||||
return True
|
||||
|
||||
print(
|
||||
f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..."
|
||||
)
|
||||
print("⚠️ Existing server process is incompatible. Should start a new server.")
|
||||
return False
|
||||
|
||||
def _start_new_server(
|
||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||
) -> tuple[bool, int]:
|
||||
"""Start a new embedding server on the given port."""
|
||||
print(f"INFO: Starting embedding server on port {port}...")
|
||||
|
||||
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
||||
|
||||
try:
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
self.backend_module_name,
|
||||
"--zmq-port",
|
||||
str(port),
|
||||
"--model-name",
|
||||
model_name,
|
||||
]
|
||||
|
||||
# Add extra arguments for specific backends
|
||||
if "passages_file" in kwargs and kwargs["passages_file"]:
|
||||
command.extend(["--passages-file", str(kwargs["passages_file"])])
|
||||
# if "distance_metric" in kwargs and kwargs["distance_metric"]:
|
||||
# command.extend(["--distance-metric", kwargs["distance_metric"]])
|
||||
if embedding_mode != "sentence-transformers":
|
||||
command.extend(["--embedding-mode", embedding_mode])
|
||||
if "enable_warmup" in kwargs and not kwargs["enable_warmup"]:
|
||||
command.extend(["--disable-warmup"])
|
||||
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
print(f"INFO: Running command from project root: {project_root}")
|
||||
print(f"INFO: Command: {' '.join(command)}") # Debug: show actual command
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
bufsize=1, # Line buffered
|
||||
universal_newlines=True,
|
||||
)
|
||||
self.server_port = port
|
||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
max_wait, wait_interval = 120, 0.5
|
||||
for _ in range(int(max_wait / wait_interval)):
|
||||
if _check_port(port):
|
||||
print("✅ Embedding server is up and ready for this session.")
|
||||
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
||||
log_thread.start()
|
||||
return True
|
||||
if self.server_process.poll() is not None:
|
||||
print(
|
||||
"❌ ERROR: Server process terminated unexpectedly during startup."
|
||||
)
|
||||
self._print_recent_output()
|
||||
return False
|
||||
time.sleep(wait_interval)
|
||||
|
||||
print(
|
||||
f"❌ ERROR: Server process failed to start listening within {max_wait} seconds."
|
||||
)
|
||||
self.stop_server()
|
||||
return False
|
||||
|
||||
self._launch_server_process(command, port)
|
||||
return self._wait_for_server_ready(port)
|
||||
except Exception as e:
|
||||
print(f"❌ ERROR: Failed to start embedding server process: {e}")
|
||||
return False
|
||||
print(f"❌ ERROR: Failed to start embedding server: {e}")
|
||||
return False, port
|
||||
|
||||
def _build_server_command(
|
||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||
) -> list:
|
||||
"""Build the command to start the embedding server."""
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
self.backend_module_name,
|
||||
"--zmq-port",
|
||||
str(port),
|
||||
"--model-name",
|
||||
model_name,
|
||||
]
|
||||
|
||||
if kwargs.get("passages_file"):
|
||||
command.extend(["--passages-file", str(kwargs["passages_file"])])
|
||||
if embedding_mode != "sentence-transformers":
|
||||
command.extend(["--embedding-mode", embedding_mode])
|
||||
|
||||
return command
|
||||
|
||||
def _launch_server_process(self, command: list, port: int) -> None:
|
||||
"""Launch the server process."""
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
print(f"INFO: Command: {' '.join(command)}")
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
bufsize=1,
|
||||
universal_newlines=True,
|
||||
)
|
||||
self.server_port = port
|
||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
# Register atexit callback only when we actually start a process
|
||||
if not self._atexit_registered:
|
||||
# Use a lambda to avoid issues with bound methods
|
||||
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
||||
self._atexit_registered = True
|
||||
|
||||
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
||||
"""Wait for the server to be ready."""
|
||||
max_wait, wait_interval = 120, 0.5
|
||||
for _ in range(int(max_wait / wait_interval)):
|
||||
if _check_port(port):
|
||||
print("✅ Embedding server is ready!")
|
||||
threading.Thread(target=self._log_monitor, daemon=True).start()
|
||||
return True, port
|
||||
|
||||
if self.server_process.poll() is not None:
|
||||
print("❌ ERROR: Server terminated during startup.")
|
||||
self._print_recent_output()
|
||||
return False, port
|
||||
|
||||
time.sleep(wait_interval)
|
||||
|
||||
print(f"❌ ERROR: Server failed to start within {max_wait} seconds.")
|
||||
self.stop_server()
|
||||
return False, port
|
||||
|
||||
def _print_recent_output(self):
|
||||
"""Print any recent output from the server process."""
|
||||
if not self.server_process or not self.server_process.stdout:
|
||||
return
|
||||
try:
|
||||
# Read any available output
|
||||
|
||||
if select.select([self.server_process.stdout], [], [], 0)[0]:
|
||||
output = self.server_process.stdout.read()
|
||||
if output:
|
||||
@@ -404,17 +339,26 @@ class EmbeddingServerManager:
|
||||
|
||||
def stop_server(self):
|
||||
"""Stops the embedding server process if it's running."""
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
if not self.server_process:
|
||||
return
|
||||
|
||||
if self.server_process.poll() is not None:
|
||||
# Process already terminated
|
||||
self.server_process = None
|
||||
return
|
||||
|
||||
print(
|
||||
f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||
)
|
||||
self.server_process.terminate()
|
||||
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
print(f"INFO: Server process {self.server_process.pid} terminated.")
|
||||
except subprocess.TimeoutExpired:
|
||||
print(
|
||||
f"INFO: Terminating session server process (PID: {self.server_process.pid})..."
|
||||
f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it."
|
||||
)
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
print("INFO: Server process terminated.")
|
||||
except subprocess.TimeoutExpired:
|
||||
print(
|
||||
"WARNING: Server process did not terminate gracefully, killing it."
|
||||
)
|
||||
self.server_process.kill()
|
||||
self.server_process.kill()
|
||||
|
||||
self.server_process = None
|
||||
|
||||
@@ -7,30 +7,37 @@ import importlib.metadata
|
||||
if TYPE_CHECKING:
|
||||
from leann.interface import LeannBackendFactoryInterface
|
||||
|
||||
BACKEND_REGISTRY: Dict[str, 'LeannBackendFactoryInterface'] = {}
|
||||
BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}
|
||||
|
||||
|
||||
def register_backend(name: str):
|
||||
"""A decorator to register a new backend class."""
|
||||
|
||||
def decorator(cls):
|
||||
print(f"INFO: Registering backend '{name}'")
|
||||
BACKEND_REGISTRY[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def autodiscover_backends():
|
||||
"""Automatically discovers and imports all 'leann-backend-*' packages."""
|
||||
print("INFO: Starting backend auto-discovery...")
|
||||
# print("INFO: Starting backend auto-discovery...")
|
||||
discovered_backends = []
|
||||
for dist in importlib.metadata.distributions():
|
||||
dist_name = dist.metadata['name']
|
||||
if dist_name.startswith('leann-backend-'):
|
||||
backend_module_name = dist_name.replace('-', '_')
|
||||
dist_name = dist.metadata["name"]
|
||||
if dist_name.startswith("leann-backend-"):
|
||||
backend_module_name = dist_name.replace("-", "_")
|
||||
discovered_backends.append(backend_module_name)
|
||||
|
||||
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
|
||||
|
||||
for backend_module_name in sorted(
|
||||
discovered_backends
|
||||
): # sort for deterministic loading
|
||||
try:
|
||||
importlib.import_module(backend_module_name)
|
||||
# Registration message is printed by the decorator
|
||||
except ImportError as e:
|
||||
print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
||||
print("INFO: Backend auto-discovery finished.")
|
||||
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
||||
pass
|
||||
# print("INFO: Backend auto-discovery finished.")
|
||||
|
||||
@@ -43,8 +43,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
"WARNING: embedding_model not found in meta.json. Recompute will fail."
|
||||
)
|
||||
|
||||
self.label_map = self._load_label_map()
|
||||
|
||||
self.embedding_server_manager = EmbeddingServerManager(
|
||||
backend_module_name=backend_module_name
|
||||
)
|
||||
@@ -58,17 +56,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
with open(meta_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
def _load_label_map(self) -> Dict[int, str]:
|
||||
"""Loads the mapping from integer IDs to string IDs."""
|
||||
label_map_file = self.index_dir / "leann.labels.map"
|
||||
if not label_map_file.exists():
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
with open(label_map_file, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def _ensure_server_running(
|
||||
self, passages_source_file: str, port: int, **kwargs
|
||||
) -> None:
|
||||
) -> int:
|
||||
"""
|
||||
Ensures the embedding server is running if recompute is needed.
|
||||
This is a helper for subclasses.
|
||||
@@ -79,8 +69,8 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
)
|
||||
|
||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
|
||||
server_started = self.embedding_server_manager.start_server(
|
||||
|
||||
server_started, actual_port = self.embedding_server_manager.start_server(
|
||||
port=port,
|
||||
model_name=self.embedding_model,
|
||||
passages_file=passages_source_file,
|
||||
@@ -89,7 +79,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
enable_warmup=kwargs.get("enable_warmup", False),
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
||||
raise RuntimeError(
|
||||
f"Failed to start embedding server on port {actual_port}"
|
||||
)
|
||||
|
||||
return actual_port
|
||||
|
||||
def compute_query_embedding(
|
||||
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
|
||||
@@ -106,12 +100,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
Query embedding as numpy array
|
||||
"""
|
||||
# Try to use embedding server if available and requested
|
||||
if (
|
||||
use_server_if_available
|
||||
and self.embedding_server_manager
|
||||
and self.embedding_server_manager.server_process
|
||||
):
|
||||
if use_server_if_available:
|
||||
try:
|
||||
# Ensure we have a server with passages_file for compatibility
|
||||
passages_source_file = (
|
||||
self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
)
|
||||
zmq_port = self._ensure_server_running(
|
||||
str(passages_source_file), zmq_port
|
||||
)
|
||||
|
||||
return self._compute_embedding_via_server([query], zmq_port)[
|
||||
0:1
|
||||
] # Return (1, D) shape
|
||||
|
||||
@@ -35,6 +35,7 @@ dependencies = [
|
||||
"llama-index-embeddings-huggingface>=0.5.5",
|
||||
"mlx>=0.26.3",
|
||||
"mlx-lm>=0.26.0",
|
||||
"psutil>=5.8.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
16
uv.lock
generated
16
uv.lock
generated
@@ -1834,10 +1834,14 @@ source = { editable = "packages/leann-core" }
|
||||
dependencies = [
|
||||
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||
{ name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||
{ name = "tqdm" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [{ name = "numpy", specifier = ">=1.20.0" }]
|
||||
requires-dist = [
|
||||
{ name = "numpy", specifier = ">=1.20.0" },
|
||||
{ name = "tqdm", specifier = ">=4.60.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "leann-workspace"
|
||||
@@ -1851,7 +1855,6 @@ dependencies = [
|
||||
{ name = "flask" },
|
||||
{ name = "flask-compress" },
|
||||
{ name = "ipykernel" },
|
||||
{ name = "leann-backend-diskann" },
|
||||
{ name = "leann-backend-hnsw" },
|
||||
{ name = "leann-core" },
|
||||
{ name = "llama-index" },
|
||||
@@ -1867,6 +1870,7 @@ dependencies = [
|
||||
{ name = "ollama" },
|
||||
{ name = "openai" },
|
||||
{ name = "protobuf" },
|
||||
{ name = "psutil" },
|
||||
{ name = "pypdf2" },
|
||||
{ name = "requests" },
|
||||
{ name = "sentence-transformers" },
|
||||
@@ -1884,6 +1888,9 @@ dev = [
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "ruff" },
|
||||
]
|
||||
diskann = [
|
||||
{ name = "leann-backend-diskann" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
@@ -1896,7 +1903,7 @@ requires-dist = [
|
||||
{ name = "flask-compress" },
|
||||
{ name = "huggingface-hub", marker = "extra == 'dev'", specifier = ">=0.20.0" },
|
||||
{ name = "ipykernel", specifier = "==6.29.5" },
|
||||
{ name = "leann-backend-diskann", editable = "packages/leann-backend-diskann" },
|
||||
{ name = "leann-backend-diskann", marker = "extra == 'diskann'", editable = "packages/leann-backend-diskann" },
|
||||
{ name = "leann-backend-hnsw", editable = "packages/leann-backend-hnsw" },
|
||||
{ name = "leann-core", editable = "packages/leann-core" },
|
||||
{ name = "llama-index", specifier = ">=0.12.44" },
|
||||
@@ -1912,6 +1919,7 @@ requires-dist = [
|
||||
{ name = "ollama" },
|
||||
{ name = "openai", specifier = ">=1.0.0" },
|
||||
{ name = "protobuf", specifier = "==4.25.3" },
|
||||
{ name = "psutil", specifier = ">=5.8.0" },
|
||||
{ name = "pypdf2", specifier = ">=3.0.0" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" },
|
||||
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" },
|
||||
@@ -1922,7 +1930,7 @@ requires-dist = [
|
||||
{ name = "torch" },
|
||||
{ name = "tqdm" },
|
||||
]
|
||||
provides-extras = ["dev"]
|
||||
provides-extras = ["dev", "diskann"]
|
||||
|
||||
[[package]]
|
||||
name = "llama-cloud"
|
||||
|
||||
Reference in New Issue
Block a user