Compare commits
8 Commits
feature/en...perf-build

| Author | SHA1 | Date |
|---|---|---|
| | 1d084f678c | |
| | 54155e8b10 | |
| | f47f76d6d7 | |
| | 1dc3923b53 | |
| | 7e226a51c9 | |
| | f4998bb316 | |
| | 7522de1d41 | |
| | 15f8bd1cc9 | |
@@ -48,7 +48,7 @@ git submodule update --init --recursive
 
 **macOS:**
 ```bash
-brew install llvm libomp boost protobuf
+brew install llvm libomp boost protobuf zeromq
 export CC=$(brew --prefix llvm)/bin/clang
 export CXX=$(brew --prefix llvm)/bin/clang++
 
@@ -61,7 +61,7 @@ uv sync --extra diskann
 
 **Linux (Ubuntu/Debian):**
 ```bash
-sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev
+sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
 
 # Install with HNSW backend (default, recommended for most users)
 uv sync
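Note: both install snippets above now pull in ZeroMQ because the HNSW backend links against the system libzmq (see the CMake change further down in this diff). A quick way to confirm the Python side can see it is a pyzmq version check; this is a sketch and assumes pyzmq is already installed in the environment.

```python
# Sanity check: confirm ZeroMQ is visible from Python (assumes pyzmq is installed).
import zmq

print("libzmq version:", zmq.zmq_version())   # version of the linked libzmq
print("pyzmq version:", zmq.pyzmq_version())  # version of the Python binding
```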
@@ -1,8 +1,8 @@
-# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
+# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
 
 cmake_minimum_required(VERSION 3.20)
 project(leann_backend_diskann_wrapper)
 
-# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
-# DiskANN 会自己处理所有事情,包括编译 Python 绑定
+# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
+# DiskANN will handle everything itself, including compiling Python bindings
 add_subdirectory(src/third_party/DiskANN)
@@ -70,10 +70,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         data_filename = f"{index_prefix}_data.bin"
         _write_vectors_to_bin(data, index_dir / data_filename)
 
-        label_map = {i: str_id for i, str_id in enumerate(ids)}
-        label_map_file = index_dir / "leann.labels.map"
-        with open(label_map_file, "wb") as f:
-            pickle.dump(label_map, f)
 
         build_kwargs = {**self.build_params, **kwargs}
         metric_enum = _get_diskann_metrics().get(
@@ -211,10 +207,7 @@ class DiskannSearcher(BaseSearcher):
         )
 
         string_labels = [
-            [
-                self.label_map.get(int_label, f"unknown_{int_label}")
-                for int_label in batch_labels
-            ]
+            [str(int_label) for int_label in batch_labels]
             for batch_labels in labels
         ]
 
@@ -76,24 +76,11 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
     finally:
         sys.path.pop(0)
 
-    # Load label map
-    passages_dir = Path(meta_file).parent
-    label_map_file = passages_dir / "leann.labels.map"
-
-    if label_map_file.exists():
-        import pickle
-        with open(label_map_file, 'rb') as f:
-            label_map = pickle.load(f)
-        print(f"Loaded label map with {len(label_map)} entries")
-    else:
-        raise FileNotFoundError(f"Label map file not found: {label_map_file}")
-
-    print(f"Initialized lazy passage loading for {len(label_map)} passages")
+    print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")
 
     class LazyPassageLoader(SimplePassageLoader):
-        def __init__(self, passage_manager, label_map):
+        def __init__(self, passage_manager):
             self.passage_manager = passage_manager
-            self.label_map = label_map
             # Initialize parent with empty data
             super().__init__({})
 
@@ -101,25 +88,22 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
             """Get passage by ID with lazy loading"""
             try:
                 int_id = int(passage_id)
-                if int_id in self.label_map:
-                    string_id = self.label_map[int_id]
-                    passage_data = self.passage_manager.get_passage(string_id)
-                    if passage_data and passage_data.get("text"):
-                        return {"text": passage_data["text"]}
-                    else:
-                        raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
+                string_id = str(int_id)
+                passage_data = self.passage_manager.get_passage(string_id)
+                if passage_data and passage_data.get("text"):
+                    return {"text": passage_data["text"]}
                 else:
-                    raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
+                    raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
             except Exception as e:
                 raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
 
         def __len__(self) -> int:
-            return len(self.label_map)
+            return len(self.passage_manager.global_offset_map)
 
         def keys(self):
-            return self.label_map.keys()
+            return self.passage_manager.global_offset_map.keys()
 
-    loader = LazyPassageLoader(passage_manager, label_map)
+    loader = LazyPassageLoader(passage_manager)
     loader._meta_path = meta_file
     return loader
 
@@ -135,35 +119,15 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
     if not passages_file.endswith('.jsonl'):
        raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
 
-    # Load label map (int -> string_id)
-    passages_dir = Path(passages_file).parent
-    label_map_file = passages_dir / "leann.labels.map"
-
-    label_map = {}
-    if label_map_file.exists():
-        with open(label_map_file, 'rb') as f:
-            label_map = pickle.load(f)
-        print(f"Loaded label map with {len(label_map)} entries")
-    else:
-        raise FileNotFoundError(f"Label map file not found: {label_map_file}")
-
-    # Load passages by string ID
-    string_id_passages = {}
+    # Load passages directly by their sequential IDs
+    passages_data = {}
     with open(passages_file, 'r', encoding='utf-8') as f:
         for line in f:
             if line.strip():
                 passage = json.loads(line)
-                string_id_passages[passage['id']] = passage['text']
+                passages_data[passage['id']] = passage['text']
 
-    # Create int ID -> text mapping using label map
-    passages_data = {}
-    for int_id, string_id in label_map.items():
-        if string_id in string_id_passages:
-            passages_data[str(int_id)] = string_id_passages[string_id]
-        else:
-            print(f"WARNING: String ID {string_id} from label map not found in passages")
-
-    print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
+    print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}")
     return SimplePassageLoader(passages_data)
 
 def create_embedding_server_thread(
@@ -8,11 +8,11 @@ version = "0.1.0"
 dependencies = ["leann-core==0.1.0", "numpy"]
 
 [tool.scikit-build]
-# 关键:简化的 CMake 路径
+# Key: simplified CMake path
 cmake.source-dir = "third_party/DiskANN"
-# 关键:Python 包在根目录,路径完全匹配
+# Key: Python package in root directory, paths match exactly
 wheel.packages = ["leann_backend_diskann"]
-# 使用默认的 redirect 模式
+# Use default redirect mode
 editable.mode = "redirect"
 cmake.build-type = "Release"
 build.verbose = true
@@ -1,6 +1,7 @@
-# 最终简化版
 cmake_minimum_required(VERSION 3.24)
 project(leann_backend_hnsw_wrapper)
+set(CMAKE_C_COMPILER_WORKS 1)
+set(CMAKE_CXX_COMPILER_WORKS 1)
 
 # Set OpenMP path for macOS
 if(APPLE)
@@ -11,15 +12,9 @@ if(APPLE)
     set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
 endif()
 
-# Build ZeroMQ from source
-set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
-set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
-set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
-set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
-set(WITH_DOCS OFF CACHE BOOL "" FORCE)
-set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
-set(BUILD_STATIC ON CACHE BOOL "" FORCE)
-add_subdirectory(third_party/libzmq)
+# Use system ZeroMQ instead of building from source
+find_package(PkgConfig REQUIRED)
+pkg_check_modules(ZMQ REQUIRED libzmq)
 
 # Add cppzmq headers
 include_directories(third_party/cppzmq)
@@ -29,6 +24,7 @@ set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
 add_compile_definitions(MSGPACK_NO_BOOST)
 include_directories(third_party/msgpack-c/include)
 
+# Faiss configuration - streamlined build
 set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
 set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
 set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
@@ -36,4 +32,24 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
 set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
 set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
 
+# Disable additional SIMD versions to speed up compilation
+set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
+set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
+
+# Additional optimization options from INSTALL.md
+set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
+set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) # Static library is faster to build
+
+# Avoid building demos and benchmarks
+set(BUILD_DEMOS OFF CACHE BOOL "" FORCE)
+set(BUILD_BENCHS OFF CACHE BOOL "" FORCE)
+
+# NEW: Tell Faiss to only build the generic version
+set(FAISS_BUILD_GENERIC ON CACHE BOOL "" FORCE)
+set(FAISS_BUILD_AVX2 OFF CACHE BOOL "" FORCE)
+set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
+
+# IMPORTANT: Disable building AVX versions to speed up compilation
+set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
+
 add_subdirectory(third_party/faiss)
@@ -59,10 +59,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         if data.dtype != np.float32:
             data = data.astype(np.float32)
 
-        label_map = {i: str_id for i, str_id in enumerate(ids)}
-        label_map_file = index_dir / "leann.labels.map"
-        with open(label_map_file, "wb") as f:
-            pickle.dump(label_map, f)
 
         metric_enum = get_metric_map().get(self.distance_metric.lower())
         if metric_enum is None:
@@ -142,13 +138,6 @@ class HNSWSearcher(BaseSearcher):
 
         self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
 
-        # Load label mapping
-        label_map_file = self.index_dir / "leann.labels.map"
-        if not label_map_file.exists():
-            raise FileNotFoundError(f"Label map file not found at {label_map_file}")
-
-        with open(label_map_file, "rb") as f:
-            self.label_map = pickle.load(f)
 
     def search(
         self,
@@ -239,10 +228,7 @@ class HNSWSearcher(BaseSearcher):
         )
 
         string_labels = [
-            [
-                self.label_map.get(int_label, f"unknown_{int_label}")
-                for int_label in batch_labels
-            ]
+            [str(int_label) for int_label in batch_labels]
             for batch_labels in labels
         ]
 
File diff suppressed because it is too large.
@@ -16,3 +16,7 @@ editable.mode = "redirect"
 cmake.build-type = "Release"
 build.verbose = true
 build.tool-args = ["-j8"]
+
+# CMake definitions to optimize compilation
+[tool.scikit-build.cmake.define]
+CMAKE_BUILD_PARALLEL_LEVEL = "8"
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 2547df4377...ff22e2c86b
@@ -15,5 +15,8 @@ dependencies = [
     "tqdm>=4.60.0"
 ]
 
+[project.scripts]
+leann = "leann.cli:main"
+
 [tool.setuptools.packages.find]
 where = ["src"]
@@ -9,9 +9,6 @@ import numpy as np
 from pathlib import Path
 from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
-import uuid
-import torch
 
 from .registry import BACKEND_REGISTRY
 from .interface import LeannBackendFactoryInterface
 from .chat import get_llm
@@ -22,7 +19,7 @@ def compute_embeddings(
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
-    use_mlx: bool = False # Backward compatibility: if True, override mode to 'mlx',
+    port: int = 5557,
 ) -> np.ndarray:
     """
     Computes embeddings using different backends.
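Note: the new `port` argument feeds `compute_embeddings_via_server` (introduced in the next hunk), which reaches the embedding server over a ZeroMQ REQ socket with msgpack-encoded payloads. A minimal sketch of that request/reply exchange follows; it assumes a server is already listening on the port, and the helper name is illustrative.

```python
# Minimal sketch of the REQ/REP exchange used by compute_embeddings_via_server.
# Assumes an embedding server is already listening on tcp://localhost:<port> and
# replies with a msgpack-encoded list of embedding vectors.
import msgpack
import numpy as np
import zmq

def request_embeddings(chunks, port=5557):
    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    socket.connect(f"tcp://localhost:{port}")
    try:
        socket.send(msgpack.packb(chunks))                 # send the raw text chunks
        embeddings_list = msgpack.unpackb(socket.recv())   # list of float vectors
        return np.array(embeddings_list, dtype=np.float32)
    finally:
        socket.close()
        context.term()
```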
@@ -39,251 +36,60 @@ def compute_embeddings(
|
|||||||
Returns:
|
Returns:
|
||||||
numpy array of embeddings
|
numpy array of embeddings
|
||||||
"""
|
"""
|
||||||
# Override mode for backward compatibility
|
if use_server:
|
||||||
if use_mlx:
|
# Use embedding server (for search/query)
|
||||||
mode = "mlx"
|
return compute_embeddings_via_server(chunks, model_name, port=port)
|
||||||
|
|
||||||
# Auto-detect mode based on model name if not explicitly set
|
|
||||||
if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
|
|
||||||
mode = "openai"
|
|
||||||
|
|
||||||
if mode == "mlx":
|
|
||||||
return compute_embeddings_mlx(chunks, model_name, batch_size=16)
|
|
||||||
elif mode == "openai":
|
|
||||||
return compute_embeddings_openai(chunks, model_name)
|
|
||||||
elif mode == "sentence-transformers":
|
|
||||||
return compute_embeddings_sentence_transformers(
|
|
||||||
chunks, model_name, use_server=use_server
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
# Use direct computation (for build_index)
|
||||||
f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
|
from .embedding_compute import (
|
||||||
|
compute_embeddings as compute_embeddings_direct,
|
||||||
|
)
|
||||||
|
|
||||||
|
return compute_embeddings_direct(
|
||||||
|
chunks,
|
||||||
|
model_name,
|
||||||
|
mode=mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_sentence_transformers(
|
def compute_embeddings_via_server(
|
||||||
chunks: List[str], model_name: str, use_server: bool = True
|
chunks: List[str], model_name: str, port: int
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Computes embeddings using sentence-transformers.
|
"""Computes embeddings using sentence-transformers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunks: List of text chunks to embed
|
chunks: List of text chunks to embed
|
||||||
model_name: Name of the sentence transformer model
|
model_name: Name of the sentence transformer model
|
||||||
use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
|
|
||||||
"""
|
"""
|
||||||
if not use_server:
|
|
||||||
print(
|
|
||||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
|
|
||||||
)
|
|
||||||
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||||
)
|
)
|
||||||
|
import zmq
|
||||||
|
import msgpack
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# Use embedding server for sentence-transformers too
|
# Connect to embedding server
|
||||||
# This avoids loading the model twice (once in API, once in server)
|
context = zmq.Context()
|
||||||
try:
|
socket = context.socket(zmq.REQ)
|
||||||
# Import ZMQ client functionality and server manager
|
socket.connect(f"tcp://localhost:{port}")
|
||||||
import zmq
|
|
||||||
import msgpack
|
|
||||||
import numpy as np
|
|
||||||
from .embedding_server_manager import EmbeddingServerManager
|
|
||||||
|
|
||||||
# Ensure embedding server is running
|
# Send chunks to server for embedding computation
|
||||||
port = 5557
|
request = chunks
|
||||||
server_manager = EmbeddingServerManager(
|
socket.send(msgpack.packb(request))
|
||||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
|
||||||
)
|
|
||||||
|
|
||||||
server_started = server_manager.start_server(
|
# Receive embeddings from server
|
||||||
port=port,
|
response = socket.recv()
|
||||||
model_name=model_name,
|
embeddings_list = msgpack.unpackb(response)
|
||||||
embedding_mode="sentence-transformers",
|
|
||||||
enable_warmup=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not server_started:
|
# Convert back to numpy array
|
||||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||||
|
|
||||||
# Connect to embedding server
|
socket.close()
|
||||||
context = zmq.Context()
|
context.term()
|
||||||
socket = context.socket(zmq.REQ)
|
|
||||||
socket.connect(f"tcp://localhost:{port}")
|
|
||||||
|
|
||||||
# Send chunks to server for embedding computation
|
|
||||||
request = chunks
|
|
||||||
socket.send(msgpack.packb(request))
|
|
||||||
|
|
||||||
# Receive embeddings from server
|
|
||||||
response = socket.recv()
|
|
||||||
embeddings_list = msgpack.unpackb(response)
|
|
||||||
|
|
||||||
# Convert back to numpy array
|
|
||||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
|
||||||
|
|
||||||
socket.close()
|
|
||||||
context.term()
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Fallback to direct sentence-transformers if server connection fails
|
|
||||||
print(
|
|
||||||
f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
|
|
||||||
)
|
|
||||||
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_embeddings_sentence_transformers_direct(
|
|
||||||
chunks: List[str], model_name: str
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Direct sentence-transformers computation (fallback)."""
|
|
||||||
try:
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
except ImportError as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
"sentence-transformers not available. Install with: uv pip install sentence-transformers"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
# Load model using sentence-transformers
|
|
||||||
model = SentenceTransformer(model_name)
|
|
||||||
|
|
||||||
model = model.half()
|
|
||||||
print(
|
|
||||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
|
|
||||||
)
|
|
||||||
# use acclerater GPU or MAC GPU
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
model = model.to("cuda")
|
|
||||||
elif torch.backends.mps.is_available():
|
|
||||||
model = model.to("mps")
|
|
||||||
|
|
||||||
# Generate embeddings
|
|
||||||
# give use an warning if OOM here means we need to turn down the batch size
|
|
||||||
embeddings = model.encode(
|
|
||||||
chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=16
|
|
||||||
)
|
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
|
|
||||||
"""Computes embeddings using OpenAI API."""
|
|
||||||
try:
|
|
||||||
import openai
|
|
||||||
import os
|
|
||||||
except ImportError as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
"openai not available. Install with: uv pip install openai"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
# Get API key from environment
|
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
|
||||||
if not api_key:
|
|
||||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
|
||||||
|
|
||||||
client = openai.OpenAI(api_key=api_key)
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# OpenAI has a limit on batch size and input length
|
|
||||||
max_batch_size = 100 # Conservative batch size
|
|
||||||
all_embeddings = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
from tqdm import tqdm
|
|
||||||
total_batches = (len(chunks) + max_batch_size - 1) // max_batch_size
|
|
||||||
batch_range = range(0, len(chunks), max_batch_size)
|
|
||||||
batch_iterator = tqdm(batch_range, desc="Computing embeddings", unit="batch", total=total_batches)
|
|
||||||
except ImportError:
|
|
||||||
# Fallback without progress bar
|
|
||||||
batch_iterator = range(0, len(chunks), max_batch_size)
|
|
||||||
|
|
||||||
for i in batch_iterator:
|
|
||||||
batch_chunks = chunks[i:i + max_batch_size]
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = client.embeddings.create(model=model_name, input=batch_chunks)
|
|
||||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
|
||||||
all_embeddings.extend(batch_embeddings)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
|
||||||
print(
|
|
||||||
f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
|
|
||||||
)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
|
|
||||||
"""Computes embeddings using an MLX model."""
|
|
||||||
try:
|
|
||||||
import mlx.core as mx
|
|
||||||
from mlx_lm.utils import load
|
|
||||||
from tqdm import tqdm
|
|
||||||
except ImportError as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load model and tokenizer
|
|
||||||
model, tokenizer = load(model_name)
|
|
||||||
|
|
||||||
# Process chunks in batches with progress bar
|
|
||||||
all_embeddings = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
from tqdm import tqdm
|
|
||||||
batch_iterator = tqdm(range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch")
|
|
||||||
except ImportError:
|
|
||||||
batch_iterator = range(0, len(chunks), batch_size)
|
|
||||||
|
|
||||||
for i in batch_iterator:
|
|
||||||
batch_chunks = chunks[i:i + batch_size]
|
|
||||||
|
|
||||||
# Tokenize all chunks in the batch
|
|
||||||
batch_token_ids = []
|
|
||||||
for chunk in batch_chunks:
|
|
||||||
token_ids = tokenizer.encode(chunk) # type: ignore
|
|
||||||
batch_token_ids.append(token_ids)
|
|
||||||
|
|
||||||
# Pad sequences to the same length for batch processing
|
|
||||||
max_length = max(len(ids) for ids in batch_token_ids)
|
|
||||||
padded_token_ids = []
|
|
||||||
for token_ids in batch_token_ids:
|
|
||||||
# Pad with tokenizer.pad_token_id or 0
|
|
||||||
padded = token_ids + [0] * (max_length - len(token_ids))
|
|
||||||
padded_token_ids.append(padded)
|
|
||||||
|
|
||||||
# Convert to MLX array with batch dimension
|
|
||||||
input_ids = mx.array(padded_token_ids)
|
|
||||||
|
|
||||||
# Get embeddings for the batch
|
|
||||||
embeddings = model(input_ids)
|
|
||||||
|
|
||||||
# Mean pooling for each sequence in the batch
|
|
||||||
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
|
|
||||||
|
|
||||||
# Convert batch embeddings to numpy
|
|
||||||
for j in range(len(batch_chunks)):
|
|
||||||
pooled_list = pooled[j].tolist() # Convert to list
|
|
||||||
pooled_numpy = np.array(pooled_list, dtype=np.float32)
|
|
||||||
all_embeddings.append(pooled_numpy)
|
|
||||||
|
|
||||||
# Stack numpy arrays
|
|
||||||
return np.stack(all_embeddings)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SearchResult:
|
class SearchResult:
|
||||||
id: str
|
id: str
|
||||||
@@ -344,14 +150,12 @@ class LeannBuilder:
         self.dimensions = dimensions
         self.embedding_mode = embedding_mode
         self.backend_kwargs = backend_kwargs
-        if 'mlx' in self.embedding_model:
-            self.embedding_mode = "mlx"
         self.chunks: List[Dict[str, Any]] = []
 
     def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
         if metadata is None:
             metadata = {}
-        passage_id = metadata.get("id", str(uuid.uuid4()))
+        passage_id = metadata.get("id", str(len(self.chunks)))
         chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
         self.chunks.append(chunk_data)
 
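Note: with the uuid default gone, passages added without an explicit id get sequential string IDs ("0", "1", ...), which is what lets the searchers above return `str(int_label)` directly instead of consulting `leann.labels.map`. A small sketch of that alignment; the chunk texts are hypothetical.

```python
# Sketch: sequential passage IDs line up with the integer labels returned by the
# index, so no separate label map is needed (chunk texts are hypothetical).
chunks = []
for text in ["first passage", "second passage", "third passage"]:
    passage_id = str(len(chunks))  # "0", "1", "2", ... as in add_text()
    chunks.append({"id": passage_id, "text": text, "metadata": {}})

# An integer label coming back from the ANN index maps straight to a passage ID.
int_labels = [2, 0]
string_ids = [str(label) for label in int_labels]
assert all(any(chunk["id"] == sid for chunk in chunks) for sid in string_ids)
```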
@@ -377,7 +181,10 @@ class LeannBuilder:
         with open(passages_file, "w", encoding="utf-8") as f:
             try:
                 from tqdm import tqdm
-                chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
+                chunk_iterator = tqdm(
+                    self.chunks, desc="Writing passages", unit="chunk"
+                )
             except ImportError:
                 chunk_iterator = self.chunks
 
@@ -398,7 +205,11 @@ class LeannBuilder:
             pickle.dump(offset_map, f)
         texts_to_embed = [c["text"] for c in self.chunks]
         embeddings = compute_embeddings(
-            texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
+            texts_to_embed,
+            self.embedding_model,
+            self.embedding_mode,
+            use_server=False,
+            port=5557,
         )
         string_ids = [chunk["id"] for chunk in self.chunks]
         current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
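Note: together with the new cli.py added below, the build path is now: construct a LeannBuilder, add chunk texts, then call build_index. A usage sketch based on the CLI code in this PR; the import path and index location are assumptions.

```python
# Usage sketch mirroring the new CLI's build flow (import path and index path
# are assumptions; constructor arguments follow the CLI's invocation).
from leann.api import LeannBuilder

builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="facebook/contriever",
)
for chunk_text in ["passage one", "passage two"]:
    builder.add_text(chunk_text)

builder.build_index("./documents.leann")
```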
packages/leann-core/src/leann/cli.py (new file, 287 lines)
@@ -0,0 +1,287 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
import os
|
||||||
|
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
from .api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
|
||||||
|
|
||||||
|
class LeannCLI:
|
||||||
|
def __init__(self):
|
||||||
|
self.indexes_dir = Path.home() / ".leann" / "indexes"
|
||||||
|
self.indexes_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self.node_parser = SentenceSplitter(
|
||||||
|
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_index_path(self, index_name: str) -> str:
|
||||||
|
index_dir = self.indexes_dir / index_name
|
||||||
|
return str(index_dir / "documents.leann")
|
||||||
|
|
||||||
|
def index_exists(self, index_name: str) -> bool:
|
||||||
|
index_dir = self.indexes_dir / index_name
|
||||||
|
meta_file = index_dir / "documents.leann.meta.json"
|
||||||
|
return meta_file.exists()
|
||||||
|
|
||||||
|
def create_parser(self) -> argparse.ArgumentParser:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="leann",
|
||||||
|
description="LEANN - Local Enhanced AI Navigation",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog="""
|
||||||
|
Examples:
|
||||||
|
leann build my-docs --docs ./documents # Build index named my-docs
|
||||||
|
leann search my-docs "query" # Search in my-docs index
|
||||||
|
leann ask my-docs "question" # Ask my-docs index
|
||||||
|
leann list # List all stored indexes
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||||
|
|
||||||
|
# Build command
|
||||||
|
build_parser = subparsers.add_parser("build", help="Build document index")
|
||||||
|
build_parser.add_argument("index_name", help="Index name")
|
||||||
|
build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
|
||||||
|
build_parser.add_argument("--backend", type=str, default="hnsw", choices=["hnsw", "diskann"])
|
||||||
|
build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
|
||||||
|
build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
|
||||||
|
build_parser.add_argument("--graph-degree", type=int, default=32)
|
||||||
|
build_parser.add_argument("--complexity", type=int, default=64)
|
||||||
|
build_parser.add_argument("--num-threads", type=int, default=1)
|
||||||
|
build_parser.add_argument("--compact", action="store_true", default=True)
|
||||||
|
build_parser.add_argument("--recompute", action="store_true", default=True)
|
||||||
|
|
||||||
|
# Search command
|
||||||
|
search_parser = subparsers.add_parser("search", help="Search documents")
|
||||||
|
search_parser.add_argument("index_name", help="Index name")
|
||||||
|
search_parser.add_argument("query", help="Search query")
|
||||||
|
search_parser.add_argument("--top-k", type=int, default=5)
|
||||||
|
search_parser.add_argument("--complexity", type=int, default=64)
|
||||||
|
search_parser.add_argument("--beam-width", type=int, default=1)
|
||||||
|
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||||
|
search_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||||
|
search_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
|
||||||
|
|
||||||
|
# Ask command
|
||||||
|
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||||
|
ask_parser.add_argument("index_name", help="Index name")
|
||||||
|
ask_parser.add_argument("--llm", type=str, default="ollama", choices=["simulated", "ollama", "hf", "openai"])
|
||||||
|
ask_parser.add_argument("--model", type=str, default="qwen3:8b")
|
||||||
|
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
||||||
|
ask_parser.add_argument("--interactive", "-i", action="store_true")
|
||||||
|
ask_parser.add_argument("--top-k", type=int, default=20)
|
||||||
|
ask_parser.add_argument("--complexity", type=int, default=32)
|
||||||
|
ask_parser.add_argument("--beam-width", type=int, default=1)
|
||||||
|
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||||
|
ask_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||||
|
ask_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
|
||||||
|
|
||||||
|
# List command
|
||||||
|
list_parser = subparsers.add_parser("list", help="List all indexes")
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def list_indexes(self):
|
||||||
|
print("Stored LEANN indexes:")
|
||||||
|
|
||||||
|
if not self.indexes_dir.exists():
|
||||||
|
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||||
|
return
|
||||||
|
|
||||||
|
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
|
||||||
|
|
||||||
|
if not index_dirs:
|
||||||
|
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Found {len(index_dirs)} indexes:")
|
||||||
|
for i, index_dir in enumerate(index_dirs, 1):
|
||||||
|
index_name = index_dir.name
|
||||||
|
status = "✓" if self.index_exists(index_name) else "✗"
|
||||||
|
|
||||||
|
print(f" {i}. {index_name} [{status}]")
|
||||||
|
if self.index_exists(index_name):
|
||||||
|
meta_file = index_dir / "documents.leann.meta.json"
|
||||||
|
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (1024 * 1024)
|
||||||
|
print(f" Size: {size_mb:.1f} MB")
|
||||||
|
|
||||||
|
if index_dirs:
|
||||||
|
example_name = index_dirs[0].name
|
||||||
|
print(f"\nUsage:")
|
||||||
|
print(f" leann search {example_name} \"your query\"")
|
||||||
|
print(f" leann ask {example_name} --interactive")
|
||||||
|
|
||||||
|
def load_documents(self, docs_dir: str):
|
||||||
|
print(f"Loading documents from {docs_dir}...")
|
||||||
|
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
docs_dir,
|
||||||
|
recursive=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
required_exts=[".pdf", ".txt", ".md", ".docx"],
|
||||||
|
).load_data(show_progress=True)
|
||||||
|
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
nodes = self.node_parser.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
async def build_index(self, args):
|
||||||
|
docs_dir = args.docs
|
||||||
|
index_name = args.index_name
|
||||||
|
index_dir = self.indexes_dir / index_name
|
||||||
|
index_path = self.get_index_path(index_name)
|
||||||
|
|
||||||
|
if index_dir.exists() and not args.force:
|
||||||
|
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
|
||||||
|
return
|
||||||
|
|
||||||
|
all_texts = self.load_documents(docs_dir)
|
||||||
|
if not all_texts:
|
||||||
|
print("No documents found")
|
||||||
|
return
|
||||||
|
|
||||||
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=args.backend,
|
||||||
|
embedding_model=args.embedding_model,
|
||||||
|
graph_degree=args.graph_degree,
|
||||||
|
complexity=args.complexity,
|
||||||
|
is_compact=args.compact,
|
||||||
|
is_recompute=args.recompute,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"Index built at {index_path}")
|
||||||
|
|
||||||
|
async def search_documents(self, args):
|
||||||
|
index_name = args.index_name
|
||||||
|
query = args.query
|
||||||
|
index_path = self.get_index_path(index_name)
|
||||||
|
|
||||||
|
if not self.index_exists(index_name):
|
||||||
|
print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
|
||||||
|
return
|
||||||
|
|
||||||
|
searcher = LeannSearcher(index_path=index_path)
|
||||||
|
results = searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.complexity,
|
||||||
|
beam_width=args.beam_width,
|
||||||
|
prune_ratio=args.prune_ratio,
|
||||||
|
recompute_embeddings=args.recompute_embeddings,
|
||||||
|
pruning_strategy=args.pruning_strategy
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Search results for '{query}' (top {len(results)}):")
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
print(f"{i}. Score: {result.score:.3f}")
|
||||||
|
print(f" {result.text[:200]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
async def ask_questions(self, args):
|
||||||
|
index_name = args.index_name
|
||||||
|
index_path = self.get_index_path(index_name)
|
||||||
|
|
||||||
|
if not self.index_exists(index_name):
|
||||||
|
print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Starting chat with index '{index_name}'...")
|
||||||
|
print(f"Using {args.model} ({args.llm})")
|
||||||
|
|
||||||
|
llm_config = {"type": args.llm, "model": args.model}
|
||||||
|
if args.llm == "ollama":
|
||||||
|
llm_config["host"] = args.host
|
||||||
|
|
||||||
|
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||||
|
|
||||||
|
if args.interactive:
|
||||||
|
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||||
|
print("=" * 40)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
user_input = input("\nYou: ").strip()
|
||||||
|
if user_input.lower() in ['quit', 'exit', 'q']:
|
||||||
|
print("Goodbye!")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not user_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
response = chat.ask(
|
||||||
|
user_input,
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.complexity,
|
||||||
|
beam_width=args.beam_width,
|
||||||
|
prune_ratio=args.prune_ratio,
|
||||||
|
recompute_embeddings=args.recompute_embeddings,
|
||||||
|
pruning_strategy=args.pruning_strategy
|
||||||
|
)
|
||||||
|
print(f"LEANN: {response}")
|
||||||
|
else:
|
||||||
|
query = input("Enter your question: ").strip()
|
||||||
|
if query:
|
||||||
|
response = chat.ask(
|
||||||
|
query,
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.complexity,
|
||||||
|
beam_width=args.beam_width,
|
||||||
|
prune_ratio=args.prune_ratio,
|
||||||
|
recompute_embeddings=args.recompute_embeddings,
|
||||||
|
pruning_strategy=args.pruning_strategy
|
||||||
|
)
|
||||||
|
print(f"LEANN: {response}")
|
||||||
|
|
||||||
|
async def run(self, args=None):
|
||||||
|
parser = self.create_parser()
|
||||||
|
|
||||||
|
if args is None:
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not args.command:
|
||||||
|
parser.print_help()
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.command == "list":
|
||||||
|
self.list_indexes()
|
||||||
|
elif args.command == "build":
|
||||||
|
await self.build_index(args)
|
||||||
|
elif args.command == "search":
|
||||||
|
await self.search_documents(args)
|
||||||
|
elif args.command == "ask":
|
||||||
|
await self.ask_questions(args)
|
||||||
|
else:
|
||||||
|
parser.print_help()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import dotenv
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
asyncio.run(cli.run())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
packages/leann-core/src/leann/embedding_compute.py (new file, 272 lines)
@@ -0,0 +1,272 @@
|
|||||||
|
"""
|
||||||
|
Unified embedding computation module
|
||||||
|
Consolidates all embedding computation logic using SentenceTransformer
|
||||||
|
Preserves all optimization parameters to ensure performance
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from typing import List
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings(
|
||||||
|
texts: List[str], model_name: str, mode: str = "sentence-transformers"
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Unified embedding computation entry point
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to compute embeddings for
|
||||||
|
model_name: Model name
|
||||||
|
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||||
|
"""
|
||||||
|
if mode == "sentence-transformers":
|
||||||
|
return compute_embeddings_sentence_transformers(texts, model_name)
|
||||||
|
elif mode == "openai":
|
||||||
|
return compute_embeddings_openai(texts, model_name)
|
||||||
|
elif mode == "mlx":
|
||||||
|
return compute_embeddings_mlx(texts, model_name)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported embedding mode: {mode}")
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_sentence_transformers(
|
||||||
|
texts: List[str],
|
||||||
|
model_name: str,
|
||||||
|
use_fp16: bool = True,
|
||||||
|
device: str = "auto",
|
||||||
|
batch_size: int = 32,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Compute embeddings using SentenceTransformer
|
||||||
|
Preserves all optimization parameters to ensure consistency with original embedding_server
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to compute embeddings for
|
||||||
|
model_name: SentenceTransformer model name
|
||||||
|
use_fp16: Whether to use FP16 precision
|
||||||
|
device: Device selection ('auto', 'cuda', 'mps', 'cpu')
|
||||||
|
batch_size: Batch size for processing
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||||
|
"""
|
||||||
|
print(
|
||||||
|
f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
# Auto-detect device
|
||||||
|
if device == "auto":
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
device = "mps"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
print(f"INFO: Using device: {device}")
|
||||||
|
|
||||||
|
# Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
|
||||||
|
model_kwargs = {
|
||||||
|
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
|
||||||
|
"low_cpu_mem_usage": True,
|
||||||
|
"_fast_init": True, # Skip weight initialization checks for faster loading
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenizer_kwargs = {
|
||||||
|
"use_fast": True, # Use fast tokenizer for better runtime performance
|
||||||
|
}
|
||||||
|
|
||||||
|
# Load SentenceTransformer (try local first, then network)
|
||||||
|
print(f"INFO: Loading SentenceTransformer model: {model_name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try local loading (avoid network delays)
|
||||||
|
model_kwargs["local_files_only"] = True
|
||||||
|
tokenizer_kwargs["local_files_only"] = True
|
||||||
|
|
||||||
|
model = SentenceTransformer(
|
||||||
|
model_name,
|
||||||
|
device=device,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
tokenizer_kwargs=tokenizer_kwargs,
|
||||||
|
local_files_only=True,
|
||||||
|
)
|
||||||
|
print("✅ Model loaded successfully! (local + optimized)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Local loading failed ({e}), trying network download...")
|
||||||
|
# Fallback to network loading
|
||||||
|
model_kwargs["local_files_only"] = False
|
||||||
|
tokenizer_kwargs["local_files_only"] = False
|
||||||
|
|
||||||
|
model = SentenceTransformer(
|
||||||
|
model_name,
|
||||||
|
device=device,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
tokenizer_kwargs=tokenizer_kwargs,
|
||||||
|
local_files_only=False,
|
||||||
|
)
|
||||||
|
print("✅ Model loaded successfully! (network + optimized)")
|
||||||
|
|
||||||
|
# Apply additional optimizations (if supported)
|
||||||
|
if use_fp16 and device in ["cuda", "mps"]:
|
||||||
|
try:
|
||||||
|
model = model.half()
|
||||||
|
model = torch.compile(model)
|
||||||
|
print(f"✅ Using FP16 precision and compile optimization: {model_name}")
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
f"FP16 or compile optimization failed, continuing with default settings: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute embeddings (using SentenceTransformer's optimized implementation)
|
||||||
|
print("INFO: Starting embedding computation...")
|
||||||
|
|
||||||
|
embeddings = model.encode(
|
||||||
|
texts,
|
||||||
|
batch_size=batch_size,
|
||||||
|
show_progress_bar=False, # Don't show progress bar in server environment
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=False, # Keep consistent with original API behavior
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate results
|
||||||
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Detected NaN or Inf values in embeddings, model: {model_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
||||||
|
"""Compute embeddings using OpenAI API"""
|
||||||
|
try:
|
||||||
|
import openai
|
||||||
|
import os
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(f"OpenAI package not installed: {e}")
|
||||||
|
|
||||||
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||||
|
|
||||||
|
client = openai.OpenAI(api_key=api_key)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# OpenAI has limits on batch size and input length
|
||||||
|
max_batch_size = 100 # Conservative batch size
|
||||||
|
all_embeddings = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
|
||||||
|
batch_range = range(0, len(texts), max_batch_size)
|
||||||
|
batch_iterator = tqdm(
|
||||||
|
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
# Fallback when tqdm is not available
|
||||||
|
batch_iterator = range(0, len(texts), max_batch_size)
|
||||||
|
|
||||||
|
for i in batch_iterator:
|
||||||
|
batch_texts = texts[i : i + max_batch_size]
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = client.embeddings.create(model=model_name, input=batch_texts)
|
||||||
|
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||||
|
all_embeddings.extend(batch_embeddings)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"ERROR: Batch {i} failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||||
|
print(
|
||||||
|
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||||
|
)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_mlx(
|
||||||
|
chunks: List[str], model_name: str, batch_size: int = 16
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Computes embeddings using an MLX model."""
|
||||||
|
try:
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx_lm.utils import load
|
||||||
|
from tqdm import tqdm
|
||||||
|
except ImportError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load model and tokenizer
|
||||||
|
model, tokenizer = load(model_name)
|
||||||
|
|
||||||
|
# Process chunks in batches with progress bar
|
||||||
|
all_embeddings = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
batch_iterator = tqdm(
|
||||||
|
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
batch_iterator = range(0, len(chunks), batch_size)
|
||||||
|
|
||||||
|
for i in batch_iterator:
|
||||||
|
batch_chunks = chunks[i : i + batch_size]
|
||||||
|
|
||||||
|
# Tokenize all chunks in the batch
|
||||||
|
batch_token_ids = []
|
||||||
|
for chunk in batch_chunks:
|
||||||
|
token_ids = tokenizer.encode(chunk) # type: ignore
|
||||||
|
batch_token_ids.append(token_ids)
|
||||||
|
|
||||||
|
# Pad sequences to the same length for batch processing
|
||||||
|
max_length = max(len(ids) for ids in batch_token_ids)
|
||||||
|
padded_token_ids = []
|
||||||
|
for token_ids in batch_token_ids:
|
||||||
|
# Pad with tokenizer.pad_token_id or 0
|
||||||
|
padded = token_ids + [0] * (max_length - len(token_ids))
|
||||||
|
padded_token_ids.append(padded)
|
||||||
|
|
||||||
|
# Convert to MLX array with batch dimension
|
||||||
|
input_ids = mx.array(padded_token_ids)
|
||||||
|
|
||||||
|
# Get embeddings for the batch
|
||||||
|
embeddings = model(input_ids)
|
||||||
|
|
||||||
|
# Mean pooling for each sequence in the batch
|
||||||
|
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
|
||||||
|
|
||||||
|
# Convert batch embeddings to numpy
|
||||||
|
for j in range(len(batch_chunks)):
|
||||||
|
pooled_list = pooled[j].tolist() # Convert to list
|
||||||
|
pooled_numpy = np.array(pooled_list, dtype=np.float32)
|
||||||
|
all_embeddings.append(pooled_numpy)
|
||||||
|
|
||||||
|
# Stack numpy arrays
|
||||||
|
return np.stack(all_embeddings)
|
||||||
@@ -4,11 +4,10 @@ import atexit
 import socket
 import subprocess
 import sys
-import zmq
-import msgpack
 from pathlib import Path
 from typing import Optional
 import select
+import psutil
 
 
 def _check_port(port: int) -> bool:
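Note: the manager now identifies an already-running embedding server by inspecting the process that owns the port with psutil, rather than sending ZeroMQ control messages. A standalone sketch of the port-to-process lookup that the following hunk implements; it assumes psutil is installed and mirrors the new helper functions below.

```python
# Sketch of the psutil-based lookup used below: find the process listening on a
# port and return its command line (assumes psutil is installed).
import psutil

def find_listener_cmdline(port: int):
    for proc in psutil.process_iter(["pid", "cmdline"]):
        try:
            for conn in proc.net_connections():
                if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
                    return proc.info["cmdline"]
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            continue
    return None

cmdline = find_listener_cmdline(5557)
print(cmdline if cmdline else "no embedding server found on port 5557")
```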
@@ -17,151 +16,135 @@ def _check_port(port: int) -> bool:
|
|||||||
return s.connect_ex(("localhost", port)) == 0
|
return s.connect_ex(("localhost", port)) == 0
|
||||||
|
|
||||||
|
|
||||||
def _check_server_meta_path(port: int, expected_meta_path: str) -> bool:
|
def _check_process_matches_config(
|
||||||
|
port: int, expected_model: str, expected_passages_file: str
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the existing server on the port is using the correct meta file.
|
Check if the process using the port matches our expected model and passages file.
|
||||||
Returns True if the server has the right meta path, False otherwise.
|
Returns True if matches, False otherwise.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
context = zmq.Context()
|
for proc in psutil.process_iter(["pid", "cmdline"]):
|
||||||
socket = context.socket(zmq.REQ)
|
if not _is_process_listening_on_port(proc, port):
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
|
continue
|
||||||
socket.connect(f"tcp://localhost:{port}")
|
|
||||||
|
|
||||||
# Send a special control message to query the server's meta path
|
cmdline = proc.info["cmdline"]
|
||||||
control_request = ["__QUERY_META_PATH__"]
|
if not cmdline:
|
||||||
request_bytes = msgpack.packb(control_request)
|
continue
|
||||||
socket.send(request_bytes)
|
|
||||||
|
|
||||||
# Wait for response
|
return _check_cmdline_matches_config(
|
||||||
response_bytes = socket.recv()
|
cmdline, port, expected_model, expected_passages_file
|
||||||
response = msgpack.unpackb(response_bytes)
|
)
|
||||||
|
|
||||||
socket.close()
|
|
||||||
context.term()
|
|
||||||
|
|
||||||
# Check if the response contains the meta path and if it matches
|
|
||||||
if isinstance(response, list) and len(response) > 0:
|
|
||||||
server_meta_path = response[0]
|
|
||||||
# Normalize paths for comparison
|
|
||||||
expected_path = Path(expected_meta_path).resolve()
|
|
||||||
server_path = Path(server_meta_path).resolve() if server_meta_path else None
|
|
||||||
return server_path == expected_path
|
|
||||||
|
|
||||||
|
print(f"DEBUG: No process found listening on port {port}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"WARNING: Could not query server meta path on port {port}: {e}")
|
print(f"WARNING: Could not check process on port {port}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
-def _update_server_meta_path(port: int, new_meta_path: str) -> bool:
-    """
-    Send a control message to update the server's meta path.
-    Returns True if successful, False otherwise.
-    """
+def _is_process_listening_on_port(proc, port: int) -> bool:
+    """Check if a process is listening on the given port."""
     try:
-        context = zmq.Context()
-        socket = context.socket(zmq.REQ)
-        socket.setsockopt(zmq.RCVTIMEO, 5000)  # 5 second timeout
-        socket.connect(f"tcp://localhost:{port}")
-
-        # Send a control message to update the meta path
-        control_request = ["__UPDATE_META_PATH__", new_meta_path]
-        request_bytes = msgpack.packb(control_request)
-        socket.send(request_bytes)
-
-        # Wait for response
-        response_bytes = socket.recv()
-        response = msgpack.unpackb(response_bytes)
-
-        socket.close()
-        context.term()
-
-        # Check if the update was successful
-        if isinstance(response, list) and len(response) > 0:
-            return response[0] == "SUCCESS"
-
+        connections = proc.net_connections()
+        for conn in connections:
+            if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
+                return True
         return False
-
-    except Exception as e:
-        print(f"ERROR: Could not update server meta path on port {port}: {e}")
+    except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
         return False


-def _check_server_model(port: int, expected_model: str) -> bool:
+def _check_cmdline_matches_config(
+    cmdline: list, port: int, expected_model: str, expected_passages_file: str
+) -> bool:
+    """Check if command line matches our expected configuration."""
+    cmdline_str = " ".join(cmdline)
+    print(f"DEBUG: Found process on port {port}: {cmdline_str}")
+
+    # Check if it's our embedding server
+    is_embedding_server = any(
+        server_type in cmdline_str
+        for server_type in [
+            "embedding_server",
+            "leann_backend_diskann.embedding_server",
+            "leann_backend_hnsw.hnsw_embedding_server",
+        ]
+    )
+
+    if not is_embedding_server:
+        print(f"DEBUG: Process on port {port} is not our embedding server")
+        return False
+
+    # Check model name
+    model_matches = _check_model_in_cmdline(cmdline, expected_model)
+
+    # Check passages file if provided
+    passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
+
+    result = model_matches and passages_matches
+    print(
+        f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
+    )
+    return result
+
+
+def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
+    """Check if the command line contains the expected model."""
+    if "--model-name" not in cmdline:
+        return False
+
+    model_idx = cmdline.index("--model-name")
+    if model_idx + 1 >= len(cmdline):
+        return False
+
+    actual_model = cmdline[model_idx + 1]
+    return actual_model == expected_model
+
+
+def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
+    """Check if the command line contains the expected passages file."""
+    if "--passages-file" not in cmdline:
+        return False  # Expected but not found
+
+    passages_idx = cmdline.index("--passages-file")
+    if passages_idx + 1 >= len(cmdline):
+        return False
+
+    actual_passages = cmdline[passages_idx + 1]
+    expected_path = Path(expected_passages_file).resolve()
+    actual_path = Path(actual_passages).resolve()
+    return actual_path == expected_path
+
+
+def _find_compatible_port_or_next_available(
+    start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
+) -> tuple[int, bool]:
     """
-    Check if the existing server on the port is using the correct embedding model.
-    Returns True if the server has the right model, False otherwise.
+    Find a port that either has a compatible server or is available.
+    Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
     """
-    try:
-        context = zmq.Context()
-        socket = context.socket(zmq.REQ)
-        socket.setsockopt(zmq.RCVTIMEO, 3000)  # 3 second timeout
-        socket.connect(f"tcp://localhost:{port}")
-
-        # Send a special control message to query the server's model
-        control_request = ["__QUERY_MODEL__"]
-        request_bytes = msgpack.packb(control_request)
-        socket.send(request_bytes)
-
-        # Wait for response
-        response_bytes = socket.recv()
-        response = msgpack.unpackb(response_bytes)
-
-        socket.close()
-        context.term()
-
-        # Check if the response contains the model name and if it matches
-        if isinstance(response, list) and len(response) > 0:
-            server_model = response[0]
-            return server_model == expected_model
-
-        return False
-
-    except Exception as e:
-        print(f"WARNING: Could not query server model on port {port}: {e}")
-        return False
-
-
-def _update_server_model(port: int, new_model: str) -> bool:
-    """
-    Send a control message to update the server's embedding model.
-    Returns True if successful, False otherwise.
-    """
-    try:
-        context = zmq.Context()
-        socket = context.socket(zmq.REQ)
-        socket.setsockopt(zmq.RCVTIMEO, 30000)  # 30 second timeout for model loading
-        socket.setsockopt(zmq.SNDTIMEO, 5000)  # 5 second timeout for sending
-        socket.connect(f"tcp://localhost:{port}")
-
-        # Send a control message to update the model
-        control_request = ["__UPDATE_MODEL__", new_model]
-        request_bytes = msgpack.packb(control_request)
-        socket.send(request_bytes)
-
-        # Wait for response
-        response_bytes = socket.recv()
-        response = msgpack.unpackb(response_bytes)
-
-        socket.close()
-        context.term()
-
-        # Check if the update was successful
-        if isinstance(response, list) and len(response) > 0:
-            return response[0] == "SUCCESS"
-
-        return False
-
-    except Exception as e:
-        print(f"ERROR: Could not update server model on port {port}: {e}")
-        return False
+    for port in range(start_port, start_port + max_attempts):
+        if not _check_port(port):
+            # Port is available
+            return port, False
+
+        # Port is in use, check if it's compatible
+        if _check_process_matches_config(port, model_name, passages_file):
+            print(f"✅ Found compatible server on port {port}")
+            return port, True
+        else:
+            print(f"⚠️ Port {port} has incompatible server, trying next port...")
+
+    raise RuntimeError(
+        f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
+    )


 class EmbeddingServerManager:
     """
-    A generic manager for handling the lifecycle of a backend-specific embedding server process.
+    A simplified manager for embedding server processes that avoids complex update mechanisms.
     """

     def __init__(self, backend_module_name: str):
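The new module-level helpers replace the old ZMQ control messages (`__QUERY_MODEL__`, `__UPDATE_META_PATH__`, ...) with a psutil process scan plus port probing. A minimal usage sketch, assuming the helpers are importable from the module this diff touches (the import path, model name, and meta path below are illustrative, not part of the change):

```python
# Sketch: resolve a port before deciding whether to launch a new embedding server.
# Assumes the helpers above live in a module importable as shown (hypothetical path).
from leann.embedding_server_manager import _find_compatible_port_or_next_available

port, reuse = _find_compatible_port_or_next_available(
    start_port=5557,                             # preferred port
    model_name="facebook/contriever",            # hypothetical model name
    passages_file="indexes/demo.meta.json",      # hypothetical meta path
)
if reuse:
    print(f"Reusing a compatible server already listening on port {port}")
else:
    print(f"Port {port} is free; a new server can be started there")
```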
@@ -175,210 +158,162 @@ class EmbeddingServerManager:
         self.backend_module_name = backend_module_name
         self.server_process: Optional[subprocess.Popen] = None
         self.server_port: Optional[int] = None
-        atexit.register(self.stop_server)
+        self._atexit_registered = False

-    def start_server(self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", **kwargs) -> bool:
+    def start_server(
+        self,
+        port: int,
+        model_name: str,
+        embedding_mode: str = "sentence-transformers",
+        **kwargs,
+    ) -> tuple[bool, int]:
         """
         Starts the embedding server process.

         Args:
-            port (int): The ZMQ port for the server.
+            port (int): The preferred ZMQ port for the server.
             model_name (str): The name of the embedding model to use.
-            **kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup).
+            **kwargs: Additional arguments for the server.

         Returns:
-            bool: True if the server is started successfully or already running, False otherwise.
+            tuple[bool, int]: (success, actual_port_used)
         """
-        if self.server_process and self.server_process.poll() is None:
-            # Even if we have a running process, check if model/meta path match
-            if self.server_port is not None:
-                port_in_use = _check_port(self.server_port)
-                if port_in_use:
-                    print(
-                        f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})"
-                    )
-
-                    # Check model compatibility
-                    model_matches = _check_server_model(self.server_port, model_name)
-                    if model_matches:
-                        print(
-                            f"✅ Existing server already using correct model: {model_name}"
-                        )
-
-                        # Still check meta path if provided
-                        passages_file = kwargs.get("passages_file")
-                        if passages_file and str(passages_file).endswith(
-                            ".meta.json"
-                        ):
-                            meta_matches = _check_server_meta_path(
-                                self.server_port, str(passages_file)
-                            )
-                            if not meta_matches:
-                                print("⚠️ Updating meta path to: {passages_file}")
-                                _update_server_meta_path(
-                                    self.server_port, str(passages_file)
-                                )
-
-                        return True
-                    else:
-                        print(
-                            f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
-                        )
-                        if not _update_server_model(self.server_port, model_name):
-                            print(
-                                "❌ Failed to update existing server model. Restarting server..."
-                            )
-                            self.stop_server()
-                            # Continue to start new server below
-                        else:
-                            print(
-                                f"✅ Successfully updated existing server model to: {model_name}"
-                            )
-
-                            # Also check meta path if provided
-                            passages_file = kwargs.get("passages_file")
-                            if passages_file and str(passages_file).endswith(
-                                ".meta.json"
-                            ):
-                                meta_matches = _check_server_meta_path(
-                                    self.server_port, str(passages_file)
-                                )
-                                if not meta_matches:
-                                    print("⚠️ Updating meta path to: {passages_file}")
-                                    _update_server_meta_path(
-                                        self.server_port, str(passages_file)
-                                    )
-
-                            return True
-                else:
-                    # Server process exists but port not responding - restart
-                    print("⚠️ Server process exists but not responding. Restarting...")
-                    self.stop_server()
-                    # Continue to start new server below
-            else:
-                # No port stored - restart
-                print("⚠️ No port information stored. Restarting server...")
-                self.stop_server()
-                # Continue to start new server below
+        passages_file = kwargs.get("passages_file")
+        assert isinstance(passages_file, str), "passages_file must be a string"
+
+        # Check if we have a compatible running server
+        if self._has_compatible_running_server(model_name, passages_file):
+            assert self.server_port is not None, (
+                "a compatible running server should set server_port"
+            )
+            return True, self.server_port
+
+        # Find available port (compatible or free)
+        try:
+            actual_port, is_compatible = _find_compatible_port_or_next_available(
+                port, model_name, passages_file
+            )
+        except RuntimeError as e:
+            print(f"❌ {e}")
+            return False, port
+
+        if is_compatible:
+            print(f"✅ Using existing compatible server on port {actual_port}")
+            self.server_port = actual_port
+            self.server_process = None  # We don't own this process
+            return True, actual_port
+
+        if actual_port != port:
+            print(f"⚠️ Using port {actual_port} instead of {port}")
+
+        # Start new server
+        return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)

-        if _check_port(port):
-            # Port is in use, check if it's using the correct meta file and model
-            passages_file = kwargs.get("passages_file")
-
-            print(f"INFO: Port {port} is in use. Checking server compatibility...")
-
-            # Check model compatibility first
-            model_matches = _check_server_model(port, model_name)
-            if model_matches:
-                print(
-                    f"✅ Existing server on port {port} is using correct model: {model_name}"
-                )
-            else:
-                print(
-                    f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
-                )
-                if not _update_server_model(port, model_name):
-                    raise RuntimeError(
-                        f"❌ Failed to update server model to {model_name}. Consider using a different port."
-                    )
-                print(f"✅ Successfully updated server model to: {model_name}")
-
-            # Check meta path compatibility if provided
-            if passages_file and str(passages_file).endswith(".meta.json"):
-                meta_matches = _check_server_meta_path(port, str(passages_file))
-                if not meta_matches:
-                    print(
-                        f"⚠️ Existing server on port {port} has different meta path. Attempting to update..."
-                    )
-                    if not _update_server_meta_path(port, str(passages_file)):
-                        raise RuntimeError(
-                            "❌ Failed to update server meta path. This may cause data synchronization issues."
-                        )
-                    print(
-                        f"✅ Successfully updated server meta path to: {passages_file}"
-                    )
-                else:
-                    print(
-                        f"✅ Existing server on port {port} is using correct meta path: {passages_file}"
-                    )
-
-            print(f"✅ Server on port {port} is compatible and ready to use.")
-            return True
-
-        print(
-            f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..."
-        )
+    def _has_compatible_running_server(
+        self, model_name: str, passages_file: str
+    ) -> bool:
+        """Check if we have a compatible running server."""
+        if not (
+            self.server_process
+            and self.server_process.poll() is None
+            and self.server_port
+        ):
+            return False
+
+        if _check_process_matches_config(self.server_port, model_name, passages_file):
+            print(
+                f"✅ Existing server process (PID {self.server_process.pid}) is compatible"
+            )
+            return True
+
+        print("⚠️ Existing server process is incompatible. Should start a new server.")
+        return False
+    def _start_new_server(
+        self, port: int, model_name: str, embedding_mode: str, **kwargs
+    ) -> tuple[bool, int]:
+        """Start a new embedding server on the given port."""
+        print(f"INFO: Starting embedding server on port {port}...")
+
+        command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
+
         try:
-            command = [
-                sys.executable,
-                "-m",
-                self.backend_module_name,
-                "--zmq-port",
-                str(port),
-                "--model-name",
-                model_name,
-            ]
-
-            # Add extra arguments for specific backends
-            if "passages_file" in kwargs and kwargs["passages_file"]:
-                command.extend(["--passages-file", str(kwargs["passages_file"])])
-            # if "distance_metric" in kwargs and kwargs["distance_metric"]:
-            #     command.extend(["--distance-metric", kwargs["distance_metric"]])
-            if embedding_mode != "sentence-transformers":
-                command.extend(["--embedding-mode", embedding_mode])
-            if "enable_warmup" in kwargs and not kwargs["enable_warmup"]:
-                command.extend(["--disable-warmup"])
-
-            project_root = Path(__file__).parent.parent.parent.parent.parent
-            print(f"INFO: Running command from project root: {project_root}")
-            print(f"INFO: Command: {' '.join(command)}")  # Debug: show actual command
-
-            self.server_process = subprocess.Popen(
-                command,
-                cwd=project_root,
-                stdout=subprocess.PIPE,
-                stderr=subprocess.STDOUT,  # Merge stderr into stdout for easier monitoring
-                text=True,
-                encoding="utf-8",
-                bufsize=1,  # Line buffered
-                universal_newlines=True,
-            )
-            self.server_port = port
-            print(f"INFO: Server process started with PID: {self.server_process.pid}")
-
-            max_wait, wait_interval = 120, 0.5
-            for _ in range(int(max_wait / wait_interval)):
-                if _check_port(port):
-                    print("✅ Embedding server is up and ready for this session.")
-                    log_thread = threading.Thread(target=self._log_monitor, daemon=True)
-                    log_thread.start()
-                    return True
-                if self.server_process.poll() is not None:
-                    print(
-                        "❌ ERROR: Server process terminated unexpectedly during startup."
-                    )
-                    self._print_recent_output()
-                    return False
-                time.sleep(wait_interval)
-
-            print(
-                f"❌ ERROR: Server process failed to start listening within {max_wait} seconds."
-            )
-            self.stop_server()
-            return False
-
+            self._launch_server_process(command, port)
+            return self._wait_for_server_ready(port)
         except Exception as e:
-            print(f"❌ ERROR: Failed to start embedding server process: {e}")
-            return False
+            print(f"❌ ERROR: Failed to start embedding server: {e}")
+            return False, port
+
+    def _build_server_command(
+        self, port: int, model_name: str, embedding_mode: str, **kwargs
+    ) -> list:
+        """Build the command to start the embedding server."""
+        command = [
+            sys.executable,
+            "-m",
+            self.backend_module_name,
+            "--zmq-port",
+            str(port),
+            "--model-name",
+            model_name,
+        ]
+
+        if kwargs.get("passages_file"):
+            command.extend(["--passages-file", str(kwargs["passages_file"])])
+        if embedding_mode != "sentence-transformers":
+            command.extend(["--embedding-mode", embedding_mode])
+
+        return command
+
+    def _launch_server_process(self, command: list, port: int) -> None:
+        """Launch the server process."""
+        project_root = Path(__file__).parent.parent.parent.parent.parent
+        print(f"INFO: Command: {' '.join(command)}")
+
+        self.server_process = subprocess.Popen(
+            command,
+            cwd=project_root,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            text=True,
+            encoding="utf-8",
+            bufsize=1,
+            universal_newlines=True,
+        )
+        self.server_port = port
+        print(f"INFO: Server process started with PID: {self.server_process.pid}")
+
+        # Register atexit callback only when we actually start a process
+        if not self._atexit_registered:
+            # Use a lambda to avoid issues with bound methods
+            atexit.register(lambda: self.stop_server() if self.server_process else None)
+            self._atexit_registered = True
+
+    def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
+        """Wait for the server to be ready."""
+        max_wait, wait_interval = 120, 0.5
+        for _ in range(int(max_wait / wait_interval)):
+            if _check_port(port):
+                print("✅ Embedding server is ready!")
+                threading.Thread(target=self._log_monitor, daemon=True).start()
+                return True, port
+
+            if self.server_process.poll() is not None:
+                print("❌ ERROR: Server terminated during startup.")
+                self._print_recent_output()
+                return False, port
+
+            time.sleep(wait_interval)
+
+        print(f"❌ ERROR: Server failed to start within {max_wait} seconds.")
+        self.stop_server()
+        return False, port

     def _print_recent_output(self):
         """Print any recent output from the server process."""
         if not self.server_process or not self.server_process.stdout:
             return
         try:
-            # Read any available output
-
             if select.select([self.server_process.stdout], [], [], 0)[0]:
                 output = self.server_process.stdout.read()
                 if output:
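With this refactor `start_server` no longer mutates a live server; it either reuses a compatible one or launches a new one on the next free port, and reports the port it actually used. A hedged sketch of the new calling convention (the backend module string, model name, and paths are assumptions for illustration only):

```python
# Sketch: consuming the new (success, actual_port) return contract.
from leann.embedding_server_manager import EmbeddingServerManager  # assumed path

manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")  # assumed module
ok, actual_port = manager.start_server(
    port=5557,                                   # preferred port; may not be the one used
    model_name="facebook/contriever",            # hypothetical model
    passages_file="indexes/demo.meta.json",      # must be a str under the new assert
)
if ok:
    print(f"Send embedding requests to tcp://localhost:{actual_port}")
manager.stop_server()  # only terminates processes this manager launched
```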
@@ -404,17 +339,26 @@ class EmbeddingServerManager:

     def stop_server(self):
         """Stops the embedding server process if it's running."""
-        if self.server_process and self.server_process.poll() is None:
-            print(
-                f"INFO: Terminating session server process (PID: {self.server_process.pid})..."
-            )
-            self.server_process.terminate()
-            try:
-                self.server_process.wait(timeout=5)
-                print("INFO: Server process terminated.")
-            except subprocess.TimeoutExpired:
-                print(
-                    "WARNING: Server process did not terminate gracefully, killing it."
-                )
-                self.server_process.kill()
+        if not self.server_process:
+            return
+
+        if self.server_process.poll() is not None:
+            # Process already terminated
+            self.server_process = None
+            return
+
+        print(
+            f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
+        )
+        self.server_process.terminate()
+
+        try:
+            self.server_process.wait(timeout=5)
+            print(f"INFO: Server process {self.server_process.pid} terminated.")
+        except subprocess.TimeoutExpired:
+            print(
+                f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it."
+            )
+            self.server_process.kill()
+
         self.server_process = None
@@ -7,30 +7,37 @@ import importlib.metadata
 if TYPE_CHECKING:
     from leann.interface import LeannBackendFactoryInterface

-BACKEND_REGISTRY: Dict[str, 'LeannBackendFactoryInterface'] = {}
+BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}


 def register_backend(name: str):
     """A decorator to register a new backend class."""

     def decorator(cls):
         print(f"INFO: Registering backend '{name}'")
         BACKEND_REGISTRY[name] = cls
         return cls

     return decorator


 def autodiscover_backends():
     """Automatically discovers and imports all 'leann-backend-*' packages."""
-    print("INFO: Starting backend auto-discovery...")
+    # print("INFO: Starting backend auto-discovery...")
     discovered_backends = []
     for dist in importlib.metadata.distributions():
-        dist_name = dist.metadata['name']
-        if dist_name.startswith('leann-backend-'):
-            backend_module_name = dist_name.replace('-', '_')
+        dist_name = dist.metadata["name"]
+        if dist_name.startswith("leann-backend-"):
+            backend_module_name = dist_name.replace("-", "_")
             discovered_backends.append(backend_module_name)

-    for backend_module_name in sorted(discovered_backends):  # sort for deterministic loading
+    for backend_module_name in sorted(
+        discovered_backends
+    ):  # sort for deterministic loading
         try:
             importlib.import_module(backend_module_name)
             # Registration message is printed by the decorator
         except ImportError as e:
-            print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
-    print("INFO: Backend auto-discovery finished.")
+            # print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
+            pass
+    # print("INFO: Backend auto-discovery finished.")
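The registry changes above are quoting style and quieter logging only; the discovery mechanism is unchanged. For context, a backend package hooks into it roughly as below (a sketch under assumptions: the `leann.registry` module path and the factory class are illustrative, not taken from this diff):

```python
# Sketch: how an installed leann-backend-* package registers its factory.
from leann.registry import register_backend  # assumed module path

@register_backend("hnsw")  # looked up later via BACKEND_REGISTRY["hnsw"]
class HNSWBackendFactory:  # hypothetical factory class implementing the interface
    @staticmethod
    def builder(**build_params):
        raise NotImplementedError  # placeholder in this sketch

    @staticmethod
    def searcher(index_path: str, **search_params):
        raise NotImplementedError  # placeholder in this sketch
```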
@@ -43,8 +43,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                 "WARNING: embedding_model not found in meta.json. Recompute will fail."
             )

-        self.label_map = self._load_label_map()
-
         self.embedding_server_manager = EmbeddingServerManager(
             backend_module_name=backend_module_name
         )
@@ -58,17 +56,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         with open(meta_path, "r", encoding="utf-8") as f:
             return json.load(f)

-    def _load_label_map(self) -> Dict[int, str]:
-        """Loads the mapping from integer IDs to string IDs."""
-        label_map_file = self.index_dir / "leann.labels.map"
-        if not label_map_file.exists():
-            raise FileNotFoundError(f"Label map file not found: {label_map_file}")
-        with open(label_map_file, "rb") as f:
-            return pickle.load(f)
-
     def _ensure_server_running(
         self, passages_source_file: str, port: int, **kwargs
-    ) -> None:
+    ) -> int:
         """
         Ensures the embedding server is running if recompute is needed.
         This is a helper for subclasses.
@@ -80,7 +70,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):

         embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")

-        server_started = self.embedding_server_manager.start_server(
+        server_started, actual_port = self.embedding_server_manager.start_server(
             port=port,
             model_name=self.embedding_model,
             passages_file=passages_source_file,
@@ -89,7 +79,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             enable_warmup=kwargs.get("enable_warmup", False),
         )
         if not server_started:
-            raise RuntimeError(f"Failed to start embedding server on port {port}")
+            raise RuntimeError(
+                f"Failed to start embedding server on port {actual_port}"
+            )
+
+        return actual_port

     def compute_query_embedding(
         self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
@@ -106,12 +100,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             Query embedding as numpy array
         """
         # Try to use embedding server if available and requested
-        if (
-            use_server_if_available
-            and self.embedding_server_manager
-            and self.embedding_server_manager.server_process
-        ):
+        if use_server_if_available:
             try:
+                # Ensure we have a server with passages_file for compatibility
+                passages_source_file = (
+                    self.index_dir / f"{self.index_path.name}.meta.json"
+                )
+                zmq_port = self._ensure_server_running(
+                    str(passages_source_file), zmq_port
+                )
+
                 return self._compute_embedding_via_server([query], zmq_port)[
                     0:1
                 ]  # Return (1, D) shape
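Because `_ensure_server_running` now returns the port the manager actually bound, subclasses should route their ZMQ traffic through that value instead of the requested one. A hedged sketch of the pattern (the subclass and its `search` signature are hypothetical; only `_ensure_server_running` and `compute_query_embedding` come from this diff):

```python
# Sketch: a backend searcher following the new port-forwarding contract.
class MySearcher(BaseSearcher):  # hypothetical subclass for illustration
    def search(self, query: str, top_k: int = 10, zmq_port: int = 5557, **kwargs):
        meta_file = self.index_dir / f"{self.index_path.name}.meta.json"
        # May return a different port than 5557 if that one was occupied.
        actual_port = self._ensure_server_running(str(meta_file), zmq_port, **kwargs)
        query_emb = self.compute_query_embedding(query, zmq_port=actual_port)
        # ...backend-specific index lookup using query_emb and actual_port...
        return query_emb
```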
@@ -35,6 +35,7 @@ dependencies = [
     "llama-index-embeddings-huggingface>=0.5.5",
     "mlx>=0.26.3",
     "mlx-lm>=0.26.0",
+    "psutil>=5.8.0",
 ]

 [project.optional-dependencies]
uv.lock (generated): 16 lines changed
@@ -1834,10 +1834,14 @@ source = { editable = "packages/leann-core" }
 dependencies = [
     { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
     { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
+    { name = "tqdm" },
 ]

 [package.metadata]
-requires-dist = [{ name = "numpy", specifier = ">=1.20.0" }]
+requires-dist = [
+    { name = "numpy", specifier = ">=1.20.0" },
+    { name = "tqdm", specifier = ">=4.60.0" },
+]

 [[package]]
 name = "leann-workspace"
@@ -1851,7 +1855,6 @@ dependencies = [
     { name = "flask" },
     { name = "flask-compress" },
     { name = "ipykernel" },
-    { name = "leann-backend-diskann" },
     { name = "leann-backend-hnsw" },
     { name = "leann-core" },
     { name = "llama-index" },
@@ -1867,6 +1870,7 @@ dependencies = [
     { name = "ollama" },
     { name = "openai" },
     { name = "protobuf" },
+    { name = "psutil" },
     { name = "pypdf2" },
     { name = "requests" },
     { name = "sentence-transformers" },
@@ -1884,6 +1888,9 @@ dev = [
     { name = "pytest-cov" },
     { name = "ruff" },
 ]
+diskann = [
+    { name = "leann-backend-diskann" },
+]

 [package.metadata]
 requires-dist = [
@@ -1896,7 +1903,7 @@ requires-dist = [
     { name = "flask-compress" },
     { name = "huggingface-hub", marker = "extra == 'dev'", specifier = ">=0.20.0" },
     { name = "ipykernel", specifier = "==6.29.5" },
-    { name = "leann-backend-diskann", editable = "packages/leann-backend-diskann" },
+    { name = "leann-backend-diskann", marker = "extra == 'diskann'", editable = "packages/leann-backend-diskann" },
     { name = "leann-backend-hnsw", editable = "packages/leann-backend-hnsw" },
     { name = "leann-core", editable = "packages/leann-core" },
     { name = "llama-index", specifier = ">=0.12.44" },
@@ -1912,6 +1919,7 @@ requires-dist = [
     { name = "ollama" },
     { name = "openai", specifier = ">=1.0.0" },
     { name = "protobuf", specifier = "==4.25.3" },
+    { name = "psutil", specifier = ">=5.8.0" },
     { name = "pypdf2", specifier = ">=3.0.0" },
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" },
     { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" },
@@ -1922,7 +1930,7 @@ requires-dist = [
     { name = "torch" },
     { name = "tqdm" },
 ]
-provides-extras = ["dev"]
+provides-extras = ["dev", "diskann"]

 [[package]]
 name = "llama-cloud"