Make DiskANN and HNSW work on main example (#2)

* fix: DiskANN ZMQ port and passages
* feat: auto-discovery of backend packages and fix passage generation for DiskANN
.gitignore (vendored)
@@ -71,3 +71,6 @@ test_indices*/
 test_*.py
 !tests/**
 packages/leann-backend-diskann/third_party/DiskANN/_deps/
+
+*.meta.json
+*.passages.json
main example script

@@ -10,7 +10,6 @@ import asyncio
 import os
 import dotenv
 from leann.api import LeannBuilder, LeannSearcher, LeannChat
-import leann_backend_hnsw  # Import to ensure backend registration
 import shutil
 from pathlib import Path
 
@@ -39,7 +38,7 @@ all_texts = []
 for doc in documents:
     nodes = node_parser.get_nodes_from_documents([doc])
     for node in nodes:
-        all_texts.append(node.text)
+        all_texts.append(node.get_content())
 
 INDEX_DIR = Path("./test_pdf_index")
 INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
@@ -51,7 +50,7 @@ if not INDEX_DIR.exists():
 
     # CSR compact mode with recompute
     builder = LeannBuilder(
-        backend_name="hnsw",
+        backend_name="diskann",
        embedding_model="facebook/contriever",
        graph_degree=32,
        complexity=64,
@@ -74,7 +73,7 @@ async def main():
 
     query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
     print(f"You: {query}")
-    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,embedding_model="facebook/contriever")
+    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True)
     print(f"Leann: {chat_response}")
 
 if __name__ == "__main__":
leann-backend-diskann: DiskANN backend

@@ -3,7 +3,7 @@ import os
 import json
 import struct
 from pathlib import Path
-from typing import Dict
+from typing import Dict, Any
 import contextlib
 import threading
 import time
@@ -36,7 +36,7 @@ def chdir(path):
     finally:
         os.chdir(original_dir)
 
-def _write_vectors_to_bin(data: np.ndarray, file_path: str):
+def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
     num_vectors, dim = data.shape
     with open(file_path, 'wb') as f:
         f.write(struct.pack('I', num_vectors))
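Note: the hunk above cuts off after the first header field. For reference, here is a self-contained sketch of the conventional DiskANN .bin layout implied by the visible struct.pack('I', num_vectors) call: num_vectors and dim as uint32, followed by row-major float32 data. The dim and data writes are assumptions, not shown in this diff.

    import struct
    from pathlib import Path
    import numpy as np

    def write_vectors_to_bin(data: np.ndarray, file_path: Path) -> None:
        # DiskANN-style .bin: two uint32 header fields, then raw float32 rows.
        num_vectors, dim = data.shape
        with open(file_path, "wb") as f:
            f.write(struct.pack("I", num_vectors))  # visible in the hunk above
            f.write(struct.pack("I", dim))          # assumed from DiskANN's file format
            f.write(np.ascontiguousarray(data, dtype=np.float32).tobytes())

    write_vectors_to_bin(np.random.rand(10, 4).astype(np.float32), Path("demo.bin"))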
@@ -67,6 +67,26 @@ class DiskannBuilder(LeannBackendBuilderInterface):
     def __init__(self, **kwargs):
         self.build_params = kwargs
 
+    def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
+        """Generate passages file for recompute mode, mirroring HNSW backend."""
+        try:
+            chunks = kwargs.get('chunks', [])
+            if not chunks:
+                print("INFO: No chunks data provided, skipping passages file generation for DiskANN.")
+                return
+
+            passages_data = {str(node_id): chunk["text"] for node_id, chunk in enumerate(chunks)}
+
+            passages_file = index_dir / f"{index_prefix}.passages.json"
+            with open(passages_file, 'w', encoding='utf-8') as f:
+                json.dump(passages_data, f, ensure_ascii=False, indent=2)
+
+            print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
+
+        except Exception as e:
+            print(f"💥 ERROR: Failed to generate passages file for DiskANN. Exception: {e}")
+            pass
+
     def build(self, data: np.ndarray, index_path: str, **kwargs):
         path = Path(index_path)
         index_dir = path.parent
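The generated file is a flat JSON object keyed by stringified positional node IDs. A minimal round-trip sketch (the two-entry chunk list is a hypothetical stand-in; the file name follows the "{index_prefix}.passages.json" convention used above):

    import json
    from pathlib import Path

    chunks = [{"text": "first passage"}, {"text": "second passage"}]

    # Exactly the mapping _generate_passages_file builds above.
    passages_data = {str(node_id): chunk["text"] for node_id, chunk in enumerate(chunks)}

    passages_file = Path("pdf_documents.passages.json")
    passages_file.write_text(json.dumps(passages_data, ensure_ascii=False, indent=2), encoding="utf-8")

    # The searcher and the embedding server later read the same mapping back.
    assert json.loads(passages_file.read_text(encoding="utf-8"))["0"] == "first passage"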
@@ -95,6 +115,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         num_threads = build_kwargs.get("num_threads", 8)
         pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
         codebook_prefix = ""
+        is_recompute = build_kwargs.get("is_recompute", False)
 
         print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
 
@@ -113,6 +134,8 @@ class DiskannBuilder(LeannBackendBuilderInterface):
                 codebook_prefix
             )
             print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
+            if is_recompute:
+                self._generate_passages_file(index_dir, index_prefix, **build_kwargs)
         except Exception as e:
             print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
             raise
@@ -141,16 +164,17 @@ class DiskannSearcher(LeannBackendSearcherInterface):
             print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
 
         path = Path(index_path)
-        index_dir = path.parent
-        index_prefix = path.stem
+        self.index_dir = path.parent
+        self.index_prefix = path.stem
 
         num_threads = kwargs.get("num_threads", 8)
         num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
+        self.zmq_port = kwargs.get("zmq_port", 6666)
 
         try:
-            full_index_prefix = str(index_dir / index_prefix)
+            full_index_prefix = str(self.index_dir / self.index_prefix)
             self._index = diskannpy.StaticDiskFloatIndex(
-                metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
+                metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
             )
             self.num_threads = num_threads
             self.embedding_server_manager = EmbeddingServerManager(
@@ -161,7 +185,7 @@ class DiskannSearcher(LeannBackendSearcherInterface):
             print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
             raise
 
-    def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
+    def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
         complexity = kwargs.get("complexity", 256)
         beam_width = kwargs.get("beam_width", 4)
 
@@ -172,23 +196,36 @@ class DiskannSearcher(LeannBackendSearcherInterface):
         prune_ratio = kwargs.get("prune_ratio", 0.0)
         batch_recompute = kwargs.get("batch_recompute", False)
         global_pruning = kwargs.get("global_pruning", False)
+        port = kwargs.get("zmq_port", self.zmq_port)
 
         if recompute_beighbor_embeddings:
             print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
             if not self.embedding_model:
                 raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
 
-            zmq_port = kwargs.get("zmq_port", 6666)
+            passages_file = kwargs.get("passages_file")
+            if not passages_file:
+                potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
+                if potential_passages_file.exists():
+                    passages_file = str(potential_passages_file)
+                    print(f"INFO: Automatically found passages file: {passages_file}")
+
+            if not passages_file:
+                raise RuntimeError(
+                    f"Recompute mode is enabled, but no passages file was found. "
+                    f"A '{self.index_prefix}.passages.json' file should exist in the index directory "
+                    f"'{self.index_dir}'. Ensure you build the index with 'recompute=True'."
+                )
+
             server_started = self.embedding_server_manager.start_server(
-                port=zmq_port,
+                port=self.zmq_port,
                 model_name=self.embedding_model,
-                distance_metric=self.distance_metric
+                distance_metric=self.distance_metric,
+                passages_file=passages_file
             )
 
             if not server_started:
-                print(f"WARNING: Failed to start embedding server, falling back to PQ computation")
-                recompute_beighbor_embeddings = False
+                raise RuntimeError(f"Failed to start DiskANN embedding server on port {self.zmq_port}")
 
         if query.dtype != np.float32:
             query = query.astype(np.float32)
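The lookup order the new search() code establishes is: explicit kwarg first, then a sibling "{index_prefix}.passages.json" next to the index, then a hard error. The same logic condensed into a standalone helper (hypothetical function name, for illustration only):

    from pathlib import Path
    from typing import Optional

    def resolve_passages_file(index_dir: Path, index_prefix: str,
                              explicit: Optional[str] = None) -> str:
        # 1) a caller-supplied path wins; 2) fall back to the file the builder
        # writes next to the index; 3) otherwise recompute cannot proceed.
        if explicit:
            return explicit
        candidate = index_dir / f"{index_prefix}.passages.json"
        if candidate.exists():
            return str(candidate)
        raise RuntimeError(
            f"Recompute mode is enabled, but no passages file was found in "
            f"'{index_dir}'. Ensure you build the index with 'recompute=True'."
        )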
leann-backend-diskann: embedding server

@@ -7,6 +7,8 @@ import pickle
 import argparse
 import threading
 import time
+import json
+from typing import Dict, Any, Optional, Union
 
 from transformers import AutoTokenizer, AutoModel
 import os
@@ -17,49 +19,49 @@ import numpy as np
 RED = "\033[91m"
 RESET = "\033[0m"
 
-# Simplified document store - replaces LazyPassages
-class SimpleDocumentStore:
-    """Simplified document store that accepts arbitrary IDs"""
-    def __init__(self, documents: dict = None):
-        self.documents = documents or {}
-        # Default demo documents
-        self.default_docs = {
-            0: "Python is a high-level, interpreted language known for simplicity.",
-            1: "Machine learning builds systems that learn from data.",
-            2: "Data structures like arrays, lists, and graphs organize data.",
-        }
-
-    def __getitem__(self, doc_id):
-        doc_id = int(doc_id)
-        # Prefer explicitly provided documents
-        if doc_id in self.documents:
-            return {"text": self.documents[doc_id]}
-        # Then fall back to the default demo documents
-        if doc_id in self.default_docs:
-            return {"text": self.default_docs[doc_id]}
-        # For any other ID, return a generic document
-        fallback_docs = [
-            "This is a general document about technology and programming concepts.",
-            "This document discusses machine learning and artificial intelligence topics.",
-            "This content covers data structures, algorithms, and computer science fundamentals.",
-            "This is a document about software engineering and development practices.",
-            "This content focuses on databases, data management, and information systems."
-        ]
-        # Pick a fallback document based on the ID
-        fallback_text = fallback_docs[doc_id % len(fallback_docs)]
-        return {"text": f"[ID:{doc_id}] {fallback_text}"}
-
-    def __len__(self):
-        return len(self.documents) + len(self.default_docs)
+# --- New Passage Loader from HNSW backend ---
+class SimplePassageLoader:
+    """
+    Simple passage loader that replaces config.py dependencies
+    """
+    def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
+        self.passages_data = passages_data or {}
+
+    def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
+        """Get passage by ID"""
+        str_id = str(passage_id)
+        if str_id in self.passages_data:
+            return {"text": self.passages_data[str_id]}
+        else:
+            # Return empty text for missing passages
+            return {"text": ""}
+
+    def __len__(self) -> int:
+        return len(self.passages_data)
+
+def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
+    """
+    Load passages from a JSON file
+    Expected format: {"passage_id": "passage_text", ...}
+    """
+    if not os.path.exists(passages_file):
+        print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
+        return SimplePassageLoader()
+
+    try:
+        with open(passages_file, 'r', encoding='utf-8') as f:
+            passages_data = json.load(f)
+        print(f"Loaded {len(passages_data)} passages from {passages_file}")
+        return SimplePassageLoader(passages_data)
+    except Exception as e:
+        print(f"Error loading passages from {passages_file}: {e}")
+        return SimplePassageLoader()
 
 def create_embedding_server_thread(
     zmq_port=5555,
     model_name="sentence-transformers/all-mpnet-base-v2",
     max_batch_size=128,
+    passages_file: Optional[str] = None,
 ):
     """
     Create and run the embedding server in the current thread
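The key behaviors of the new loader (string keys, int IDs coerced, empty text instead of a KeyError for misses) in a runnable nutshell; the class body is condensed from the diff above:

    from typing import Any, Dict, Optional, Union

    class SimplePassageLoader:
        # Condensed copy of the class added above, for a standalone demo.
        def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
            self.passages_data = passages_data or {}
        def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
            return {"text": self.passages_data.get(str(passage_id), "")}
        def __len__(self) -> int:
            return len(self.passages_data)

    loader = SimplePassageLoader({"0": "Python is a high-level language."})
    print(loader[0])     # int IDs are stringified: {'text': 'Python is a high-level language.'}
    print(loader["42"])  # missing IDs return empty text rather than raising: {'text': ''}
    print(len(loader))   # 1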
@@ -109,15 +111,14 @@ def create_embedding_server_thread(
     except Exception as e:
         print(f"WARNING: Model optimization failed: {e}")
 
-    # Default demo documents
-    demo_documents = {
-        0: "Python is a high-level, interpreted language known for simplicity.",
-        1: "Machine learning builds systems that learn from data.",
-        2: "Data structures like arrays, lists, and graphs organize data.",
-    }
-
-    passages = SimpleDocumentStore(demo_documents)
-    print(f"INFO: Loaded {len(passages)} demo documents")
+    # Load passages from file if provided
+    if passages_file and os.path.exists(passages_file):
+        passages = load_passages_from_file(passages_file)
+    else:
+        print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
+        passages = SimplePassageLoader()
+
+    print(f"INFO: Loaded {len(passages)} passages.")
 
 class DeviceTimer:
     """Device timer"""
@@ -264,7 +265,13 @@ def create_embedding_server_thread(
             for nid in node_ids:
                 txtinfo = passages[nid]
                 txt = txtinfo["text"]
-                texts.append(txt)
+                if txt:
+                    texts.append(txt)
+                else:
+                    # If the text is empty we still need a placeholder for batching,
+                    # but record its ID as missing
+                    texts.append("")
+                    missing_ids.append(nid)
             lookup_timer.print_elapsed()
 
             if missing_ids:
@@ -360,18 +367,20 @@ def create_embedding_server(
     max_batch_size=128,
     lazy_load_passages=False,
     model_name="sentence-transformers/all-mpnet-base-v2",
+    passages_file: Optional[str] = None,
 ):
     """
     The original create_embedding_server function stays unchanged;
     this is the blocking version, for running directly
     """
-    create_embedding_server_thread(zmq_port, model_name, max_batch_size)
+    create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Embedding service")
     parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
     parser.add_argument("--domain", type=str, default="demo", help="Domain name")
+    parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
     parser.add_argument("--load-passages", action="store_true", default=True)
     parser.add_argument("--load-embeddings", action="store_true", default=False)
     parser.add_argument("--use-fp16", action="store_true", default=False)

@@ -394,4 +403,5 @@ if __name__ == "__main__":
         max_batch_size=args.max_batch_size,
         lazy_load_passages=args.lazy_load_passages,
         model_name=args.model_name,
+        passages_file=args.passages_file,
     )
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: 015c201141...c7a9d681cb
leann-backend-hnsw: HNSW backend

@@ -303,7 +303,8 @@ class HNSWSearcher(LeannBackendSearcherInterface):
         hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
         hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
         hnsw_config.external_storage_path = kwargs.get("external_storage_path")
-        hnsw_config.zmq_port = kwargs.get("zmq_port", 5557)
+        self.zmq_port = kwargs.get("zmq_port", 5557)
+
 
         if self.is_pruned and not hnsw_config.is_recompute:
             raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
@@ -361,13 +362,15 @@ class HNSWSearcher(LeannBackendSearcherInterface):
             faiss.normalize_L2(query)
 
         try:
-            self._index.hnsw.efSearch = ef
+            params = faiss.SearchParametersHNSW()
+            params.efSearch = ef
+            params.zmq_port = kwargs.get("zmq_port", self.zmq_port)
 
             batch_size = query.shape[0]
             distances = np.empty((batch_size, top_k), dtype=np.float32)
             labels = np.empty((batch_size, top_k), dtype=np.int64)
 
-            self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
+            self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
 
             return {"labels": labels, "distances": distances}
 
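Switching from mutating the index-wide hnsw.efSearch to a per-call SearchParametersHNSW is what lets each search carry its own settings, including the zmq_port field, which exists only on LEANN's patched faiss. A minimal sketch of the same pattern against stock faiss (zmq_port omitted, since upstream faiss has no such field):

    import numpy as np
    import faiss

    d, k = 32, 5
    xb = np.random.rand(1000, d).astype(np.float32)
    index = faiss.IndexHNSWFlat(d, 16)
    index.add(xb)

    # Per-query parameters instead of mutating index.hnsw.efSearch globally,
    # so concurrent searches can safely use different settings.
    params = faiss.SearchParametersHNSW()
    params.efSearch = 64

    xq = np.random.rand(1, d).astype(np.float32)
    D, I = index.search(xq, k, params=params)  # keyword form of the swig_ptr call above
    print(I)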
packages/leann-core/src/leann/__init__.py (new file)

@@ -0,0 +1,7 @@
+# packages/leann-core/src/leann/__init__.py
+from .api import LeannBuilder, LeannChat, LeannSearcher
+from .registry import BACKEND_REGISTRY, autodiscover_backends
+
+autodiscover_backends()
+
+__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"]
packages/leann-core/src/leann/registry.py

@@ -1,6 +1,9 @@
 # packages/leann-core/src/leann/registry.py
+
 from typing import Dict, TYPE_CHECKING
+import importlib
+import importlib.metadata
 
 if TYPE_CHECKING:
     from leann.interface import LeannBackendFactoryInterface
 
@@ -12,4 +15,22 @@ def register_backend(name: str):
         print(f"INFO: Registering backend '{name}'")
         BACKEND_REGISTRY[name] = cls
         return cls
     return decorator
+
+def autodiscover_backends():
+    """Automatically discovers and imports all 'leann-backend-*' packages."""
+    print("INFO: Starting backend auto-discovery...")
+    discovered_backends = []
+    for dist in importlib.metadata.distributions():
+        dist_name = dist.metadata['name']
+        if dist_name.startswith('leann-backend-'):
+            backend_module_name = dist_name.replace('-', '_')
+            discovered_backends.append(backend_module_name)
+
+    for backend_module_name in sorted(discovered_backends):  # sort for deterministic loading
+        try:
+            importlib.import_module(backend_module_name)
+            # Registration message is printed by the decorator
+        except ImportError as e:
+            print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
+    print("INFO: Backend auto-discovery finished.")
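Because the new __init__.py runs autodiscover_backends() at import time, "import leann" alone now registers every installed backend, which is why the example dropped its explicit "import leann_backend_hnsw". The discovery step itself is just a scan over installed distributions; a standalone sketch:

    import importlib.metadata

    # Same filter autodiscover_backends() applies, without importing anything.
    discovered = sorted(
        dist.metadata["name"].replace("-", "_")
        for dist in importlib.metadata.distributions()
        if dist.metadata["name"].startswith("leann-backend-")
    )
    print(discovered)  # e.g. ['leann_backend_diskann', 'leann_backend_hnsw']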
File diff suppressed because it is too large.