Compare commits
4 Commits
fix-arm64-
...
fix-all
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
16705fc44a | ||
|
|
5611f708e9 | ||
|
|
b4ae57b2c0 | ||
|
|
5659174635 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -71,3 +71,6 @@ test_indices*/
|
|||||||
test_*.py
|
test_*.py
|
||||||
!tests/**
|
!tests/**
|
||||||
packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||||
|
|
||||||
|
*.meta.json
|
||||||
|
*.passages.json
|
||||||
@@ -74,7 +74,7 @@ def main():
|
|||||||
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
||||||
print(">>> Basic search results <<<")
|
print(">>> Basic search results <<<")
|
||||||
for i, res in enumerate(results, 1):
|
for i, res in enumerate(results, 1):
|
||||||
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
|
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
|
||||||
|
|
||||||
# --- 3. Recompute search demo ---
|
# --- 3. Recompute search demo ---
|
||||||
print(f"\n[PHASE 3] Recompute search using embedding server...")
|
print(f"\n[PHASE 3] Recompute search using embedding server...")
|
||||||
@@ -107,7 +107,7 @@ def main():
|
|||||||
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
||||||
print(">>> Recompute search results <<<")
|
print(">>> Recompute search results <<<")
|
||||||
for i, res in enumerate(recompute_results, 1):
|
for i, res in enumerate(recompute_results, 1):
|
||||||
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
|
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
|
||||||
|
|
||||||
# Compare results
|
# Compare results
|
||||||
print(f"\n--- Result comparison ---")
|
print(f"\n--- Result comparison ---")
|
||||||
@@ -116,8 +116,8 @@ def main():
|
|||||||
|
|
||||||
print("\nBasic search vs Recompute results:")
|
print("\nBasic search vs Recompute results:")
|
||||||
for i in range(min(len(results), len(recompute_results))):
|
for i in range(min(len(results), len(recompute_results))):
|
||||||
basic_score = results[i]['score']
|
basic_score = results[i].score
|
||||||
recompute_score = recompute_results[i]['score']
|
recompute_score = recompute_results[i].score
|
||||||
score_diff = abs(basic_score - recompute_score)
|
score_diff = abs(basic_score - recompute_score)
|
||||||
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
|
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
import dotenv
|
import dotenv
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
import leann_backend_hnsw # Import to ensure backend registration
|
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -39,7 +38,7 @@ all_texts = []
|
|||||||
for doc in documents:
|
for doc in documents:
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append(node.text)
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
INDEX_DIR = Path("./test_pdf_index")
|
INDEX_DIR = Path("./test_pdf_index")
|
||||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||||
@@ -51,7 +50,7 @@ if not INDEX_DIR.exists():
|
|||||||
|
|
||||||
# CSR compact mode with recompute
|
# CSR compact mode with recompute
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name="hnsw",
|
backend_name="diskann",
|
||||||
embedding_model="facebook/contriever",
|
embedding_model="facebook/contriever",
|
||||||
graph_degree=32,
|
graph_degree=32,
|
||||||
complexity=64,
|
complexity=64,
|
||||||
@@ -74,7 +73,7 @@ async def main():
|
|||||||
|
|
||||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||||
print(f"You: {query}")
|
print(f"You: {query}")
|
||||||
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,embedding_model="facebook/contriever")
|
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True)
|
||||||
print(f"Leann: {chat_response}")
|
print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
import json
|
import json
|
||||||
import struct
|
import struct
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict, Any, List
|
||||||
import contextlib
|
import contextlib
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@@ -11,6 +11,7 @@ import atexit
|
|||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import pickle
|
||||||
|
|
||||||
from leann.embedding_server_manager import EmbeddingServerManager
|
from leann.embedding_server_manager import EmbeddingServerManager
|
||||||
from leann.registry import register_backend
|
from leann.registry import register_backend
|
||||||
@@ -19,13 +20,13 @@ from leann.interface import (
|
|||||||
LeannBackendBuilderInterface,
|
LeannBackendBuilderInterface,
|
||||||
LeannBackendSearcherInterface
|
LeannBackendSearcherInterface
|
||||||
)
|
)
|
||||||
from . import _diskannpy as diskannpy
|
def _get_diskann_metrics():
|
||||||
|
from . import _diskannpy as diskannpy
|
||||||
METRIC_MAP = {
|
return {
|
||||||
"mips": diskannpy.Metric.INNER_PRODUCT,
|
"mips": diskannpy.Metric.INNER_PRODUCT,
|
||||||
"l2": diskannpy.Metric.L2,
|
"l2": diskannpy.Metric.L2,
|
||||||
"cosine": diskannpy.Metric.COSINE,
|
"cosine": diskannpy.Metric.COSINE,
|
||||||
}
|
}
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def chdir(path):
|
def chdir(path):
|
||||||
@@ -36,7 +37,7 @@ def chdir(path):
|
|||||||
finally:
|
finally:
|
||||||
os.chdir(original_dir)
|
os.chdir(original_dir)
|
||||||
|
|
||||||
def _write_vectors_to_bin(data: np.ndarray, file_path: str):
|
def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
|
||||||
num_vectors, dim = data.shape
|
num_vectors, dim = data.shape
|
||||||
with open(file_path, 'wb') as f:
|
with open(file_path, 'wb') as f:
|
||||||
f.write(struct.pack('I', num_vectors))
|
f.write(struct.pack('I', num_vectors))
|
||||||
@@ -67,7 +68,8 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.build_params = kwargs
|
self.build_params = kwargs
|
||||||
|
|
||||||
def build(self, data: np.ndarray, index_path: str, **kwargs):
|
|
||||||
|
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
index_dir = path.parent
|
||||||
index_prefix = path.stem
|
index_prefix = path.stem
|
||||||
@@ -82,8 +84,15 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
data_filename = f"{index_prefix}_data.bin"
|
data_filename = f"{index_prefix}_data.bin"
|
||||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||||
|
|
||||||
|
# Create label map: integer -> string_id
|
||||||
|
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||||
|
label_map_file = index_dir / "leann.labels.map"
|
||||||
|
with open(label_map_file, 'wb') as f:
|
||||||
|
pickle.dump(label_map, f)
|
||||||
|
|
||||||
build_kwargs = {**self.build_params, **kwargs}
|
build_kwargs = {**self.build_params, **kwargs}
|
||||||
metric_str = build_kwargs.get("distance_metric", "mips").lower()
|
metric_str = build_kwargs.get("distance_metric", "mips").lower()
|
||||||
|
METRIC_MAP = _get_diskann_metrics()
|
||||||
metric_enum = METRIC_MAP.get(metric_str)
|
metric_enum = METRIC_MAP.get(metric_str)
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
||||||
@@ -99,6 +108,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
|
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from . import _diskannpy as diskannpy
|
||||||
with chdir(index_dir):
|
with chdir(index_dir):
|
||||||
diskannpy.build_disk_float_index(
|
diskannpy.build_disk_float_index(
|
||||||
metric_enum,
|
metric_enum,
|
||||||
@@ -127,30 +137,38 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
|||||||
if not self.meta:
|
if not self.meta:
|
||||||
raise ValueError("DiskannSearcher requires metadata from .meta.json.")
|
raise ValueError("DiskannSearcher requires metadata from .meta.json.")
|
||||||
|
|
||||||
dimensions = self.meta.get("dimensions")
|
|
||||||
if not dimensions:
|
|
||||||
raise ValueError("Dimensions not found in Leann metadata.")
|
|
||||||
|
|
||||||
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
|
|
||||||
metric_enum = METRIC_MAP.get(self.distance_metric)
|
|
||||||
if metric_enum is None:
|
|
||||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
|
||||||
|
|
||||||
self.embedding_model = self.meta.get("embedding_model")
|
self.embedding_model = self.meta.get("embedding_model")
|
||||||
if not self.embedding_model:
|
if not self.embedding_model:
|
||||||
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
|
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
|
||||||
|
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
self.index_dir = path.parent
|
||||||
index_prefix = path.stem
|
self.index_prefix = path.stem
|
||||||
|
|
||||||
|
# Load the label map
|
||||||
|
label_map_file = self.index_dir / "leann.labels.map"
|
||||||
|
if not label_map_file.exists():
|
||||||
|
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||||
|
|
||||||
|
with open(label_map_file, 'rb') as f:
|
||||||
|
self.label_map = pickle.load(f)
|
||||||
|
|
||||||
|
# Extract parameters for DiskANN
|
||||||
|
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||||
|
METRIC_MAP = _get_diskann_metrics()
|
||||||
|
metric_enum = METRIC_MAP.get(distance_metric)
|
||||||
|
if metric_enum is None:
|
||||||
|
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||||
|
|
||||||
num_threads = kwargs.get("num_threads", 8)
|
num_threads = kwargs.get("num_threads", 8)
|
||||||
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
|
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
|
||||||
|
self.zmq_port = kwargs.get("zmq_port", 6666)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
full_index_prefix = str(index_dir / index_prefix)
|
from . import _diskannpy as diskannpy
|
||||||
|
full_index_prefix = str(self.index_dir / self.index_prefix)
|
||||||
self._index = diskannpy.StaticDiskFloatIndex(
|
self._index = diskannpy.StaticDiskFloatIndex(
|
||||||
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
|
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
|
||||||
)
|
)
|
||||||
self.num_threads = num_threads
|
self.num_threads = num_threads
|
||||||
self.embedding_server_manager = EmbeddingServerManager(
|
self.embedding_server_manager = EmbeddingServerManager(
|
||||||
@@ -161,7 +179,7 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
|||||||
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
|
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
|
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
|
||||||
complexity = kwargs.get("complexity", 256)
|
complexity = kwargs.get("complexity", 256)
|
||||||
beam_width = kwargs.get("beam_width", 4)
|
beam_width = kwargs.get("beam_width", 4)
|
||||||
|
|
||||||
@@ -172,23 +190,32 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
|||||||
prune_ratio = kwargs.get("prune_ratio", 0.0)
|
prune_ratio = kwargs.get("prune_ratio", 0.0)
|
||||||
batch_recompute = kwargs.get("batch_recompute", False)
|
batch_recompute = kwargs.get("batch_recompute", False)
|
||||||
global_pruning = kwargs.get("global_pruning", False)
|
global_pruning = kwargs.get("global_pruning", False)
|
||||||
|
port = kwargs.get("zmq_port", self.zmq_port)
|
||||||
|
|
||||||
if recompute_beighbor_embeddings:
|
if recompute_beighbor_embeddings:
|
||||||
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
|
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
|
||||||
if not self.embedding_model:
|
if not self.embedding_model:
|
||||||
raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
|
raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
|
||||||
|
|
||||||
zmq_port = kwargs.get("zmq_port", 6666)
|
passages_file = kwargs.get("passages_file")
|
||||||
|
if not passages_file:
|
||||||
|
# Get the passages file path from meta.json
|
||||||
|
if 'passage_sources' in self.meta and self.meta['passage_sources']:
|
||||||
|
passage_source = self.meta['passage_sources'][0]
|
||||||
|
passages_file = passage_source['path']
|
||||||
|
print(f"INFO: Found passages file from metadata: {passages_file}")
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"FATAL: Recompute mode enabled but no passage_sources found in metadata.")
|
||||||
|
|
||||||
server_started = self.embedding_server_manager.start_server(
|
server_started = self.embedding_server_manager.start_server(
|
||||||
port=zmq_port,
|
port=self.zmq_port,
|
||||||
model_name=self.embedding_model,
|
model_name=self.embedding_model,
|
||||||
distance_metric=self.distance_metric
|
distance_metric=kwargs.get("distance_metric", "mips"),
|
||||||
|
passages_file=passages_file
|
||||||
)
|
)
|
||||||
|
|
||||||
if not server_started:
|
if not server_started:
|
||||||
print(f"WARNING: Failed to start embedding server, falling back to PQ computation")
|
raise RuntimeError(f"Failed to start DiskANN embedding server on port {self.zmq_port}")
|
||||||
recompute_beighbor_embeddings = False
|
|
||||||
|
|
||||||
if query.dtype != np.float32:
|
if query.dtype != np.float32:
|
||||||
query = query.astype(np.float32)
|
query = query.astype(np.float32)
|
||||||
@@ -211,11 +238,23 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
|||||||
batch_recompute,
|
batch_recompute,
|
||||||
global_pruning
|
global_pruning
|
||||||
)
|
)
|
||||||
return {"labels": labels, "distances": distances}
|
|
||||||
|
# Convert integer labels to string IDs
|
||||||
|
string_labels = []
|
||||||
|
for batch_labels in labels:
|
||||||
|
batch_string_labels = []
|
||||||
|
for int_label in batch_labels:
|
||||||
|
if int_label in self.label_map:
|
||||||
|
batch_string_labels.append(self.label_map[int_label])
|
||||||
|
else:
|
||||||
|
batch_string_labels.append(f"unknown_{int_label}")
|
||||||
|
string_labels.append(batch_string_labels)
|
||||||
|
|
||||||
|
return {"labels": string_labels, "distances": distances}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
|
print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
|
||||||
batch_size = query.shape[0]
|
batch_size = query.shape[0]
|
||||||
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
|
return {"labels": [[f"error_{i}" for i in range(top_k)] for _ in range(batch_size)],
|
||||||
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
|
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import pickle
|
|||||||
import argparse
|
import argparse
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any, Optional, Union
|
||||||
|
|
||||||
from transformers import AutoTokenizer, AutoModel
|
from transformers import AutoTokenizer, AutoModel
|
||||||
import os
|
import os
|
||||||
@@ -17,49 +19,76 @@ import numpy as np
|
|||||||
RED = "\033[91m"
|
RED = "\033[91m"
|
||||||
RESET = "\033[0m"
|
RESET = "\033[0m"
|
||||||
|
|
||||||
# 简化的文档存储 - 替代 LazyPassages
|
# --- New Passage Loader from HNSW backend ---
|
||||||
class SimpleDocumentStore:
|
class SimplePassageLoader:
|
||||||
"""简化的文档存储,支持任意ID"""
|
"""
|
||||||
def __init__(self, documents: dict = None):
|
Simple passage loader that replaces config.py dependencies
|
||||||
self.documents = documents or {}
|
"""
|
||||||
# 默认演示文档
|
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
|
||||||
self.default_docs = {
|
self.passages_data = passages_data or {}
|
||||||
0: "Python is a high-level, interpreted language known for simplicity.",
|
|
||||||
1: "Machine learning builds systems that learn from data.",
|
|
||||||
2: "Data structures like arrays, lists, and graphs organize data.",
|
|
||||||
}
|
|
||||||
|
|
||||||
def __getitem__(self, doc_id):
|
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
||||||
doc_id = int(doc_id)
|
"""Get passage by ID"""
|
||||||
|
str_id = str(passage_id)
|
||||||
# 优先使用指定的文档
|
if str_id in self.passages_data:
|
||||||
if doc_id in self.documents:
|
return {"text": self.passages_data[str_id]}
|
||||||
return {"text": self.documents[doc_id]}
|
else:
|
||||||
|
# Return empty text for missing passages
|
||||||
# 其次使用默认演示文档
|
return {"text": ""}
|
||||||
if doc_id in self.default_docs:
|
|
||||||
return {"text": self.default_docs[doc_id]}
|
|
||||||
|
|
||||||
# 对于任意其他ID,返回通用文档
|
|
||||||
fallback_docs = [
|
|
||||||
"This is a general document about technology and programming concepts.",
|
|
||||||
"This document discusses machine learning and artificial intelligence topics.",
|
|
||||||
"This content covers data structures, algorithms, and computer science fundamentals.",
|
|
||||||
"This is a document about software engineering and development practices.",
|
|
||||||
"This content focuses on databases, data management, and information systems."
|
|
||||||
]
|
|
||||||
|
|
||||||
# 根据ID选择一个fallback文档
|
|
||||||
fallback_text = fallback_docs[doc_id % len(fallback_docs)]
|
|
||||||
return {"text": f"[ID:{doc_id}] {fallback_text}"}
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
return len(self.documents) + len(self.default_docs)
|
return len(self.passages_data)
|
||||||
|
|
||||||
|
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
|
||||||
|
"""
|
||||||
|
Load passages from a JSONL file with label map support
|
||||||
|
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
if not os.path.exists(passages_file):
|
||||||
|
raise FileNotFoundError(f"Passages file {passages_file} not found.")
|
||||||
|
|
||||||
|
if not passages_file.endswith('.jsonl'):
|
||||||
|
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
|
||||||
|
|
||||||
|
# Load label map (int -> string_id)
|
||||||
|
passages_dir = Path(passages_file).parent
|
||||||
|
label_map_file = passages_dir / "leann.labels.map"
|
||||||
|
|
||||||
|
label_map = {}
|
||||||
|
if label_map_file.exists():
|
||||||
|
with open(label_map_file, 'rb') as f:
|
||||||
|
label_map = pickle.load(f)
|
||||||
|
print(f"Loaded label map with {len(label_map)} entries")
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||||
|
|
||||||
|
# Load passages by string ID
|
||||||
|
string_id_passages = {}
|
||||||
|
with open(passages_file, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
passage = json.loads(line)
|
||||||
|
string_id_passages[passage['id']] = passage['text']
|
||||||
|
|
||||||
|
# Create int ID -> text mapping using label map
|
||||||
|
passages_data = {}
|
||||||
|
for int_id, string_id in label_map.items():
|
||||||
|
if string_id in string_id_passages:
|
||||||
|
passages_data[str(int_id)] = string_id_passages[string_id]
|
||||||
|
else:
|
||||||
|
print(f"WARNING: String ID {string_id} from label map not found in passages")
|
||||||
|
|
||||||
|
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
|
||||||
|
return SimplePassageLoader(passages_data)
|
||||||
|
|
||||||
def create_embedding_server_thread(
|
def create_embedding_server_thread(
|
||||||
zmq_port=5555,
|
zmq_port=5555,
|
||||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||||
max_batch_size=128,
|
max_batch_size=128,
|
||||||
|
passages_file: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
在当前线程中创建并运行 embedding server
|
在当前线程中创建并运行 embedding server
|
||||||
@@ -109,15 +138,14 @@ def create_embedding_server_thread(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"WARNING: Model optimization failed: {e}")
|
print(f"WARNING: Model optimization failed: {e}")
|
||||||
|
|
||||||
# 默认演示文档
|
# Load passages from file if provided
|
||||||
demo_documents = {
|
if passages_file and os.path.exists(passages_file):
|
||||||
0: "Python is a high-level, interpreted language known for simplicity.",
|
passages = load_passages_from_file(passages_file)
|
||||||
1: "Machine learning builds systems that learn from data.",
|
else:
|
||||||
2: "Data structures like arrays, lists, and graphs organize data.",
|
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
|
||||||
}
|
passages = SimplePassageLoader()
|
||||||
|
|
||||||
passages = SimpleDocumentStore(demo_documents)
|
print(f"INFO: Loaded {len(passages)} passages.")
|
||||||
print(f"INFO: Loaded {len(passages)} demo documents")
|
|
||||||
|
|
||||||
class DeviceTimer:
|
class DeviceTimer:
|
||||||
"""设备计时器"""
|
"""设备计时器"""
|
||||||
@@ -264,7 +292,13 @@ def create_embedding_server_thread(
|
|||||||
for nid in node_ids:
|
for nid in node_ids:
|
||||||
txtinfo = passages[nid]
|
txtinfo = passages[nid]
|
||||||
txt = txtinfo["text"]
|
txt = txtinfo["text"]
|
||||||
texts.append(txt)
|
if txt:
|
||||||
|
texts.append(txt)
|
||||||
|
else:
|
||||||
|
# 如果文本为空,我们仍然需要一个占位符来进行批处理,
|
||||||
|
# 但将其ID记录为缺失
|
||||||
|
texts.append("")
|
||||||
|
missing_ids.append(nid)
|
||||||
lookup_timer.print_elapsed()
|
lookup_timer.print_elapsed()
|
||||||
|
|
||||||
if missing_ids:
|
if missing_ids:
|
||||||
@@ -360,18 +394,20 @@ def create_embedding_server(
|
|||||||
max_batch_size=128,
|
max_batch_size=128,
|
||||||
lazy_load_passages=False,
|
lazy_load_passages=False,
|
||||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
passages_file: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
原有的 create_embedding_server 函数保持不变
|
原有的 create_embedding_server 函数保持不变
|
||||||
这个是阻塞版本,用于直接运行
|
这个是阻塞版本,用于直接运行
|
||||||
"""
|
"""
|
||||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size)
|
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Embedding service")
|
parser = argparse.ArgumentParser(description="Embedding service")
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
|
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
|
||||||
|
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
|
||||||
parser.add_argument("--load-passages", action="store_true", default=True)
|
parser.add_argument("--load-passages", action="store_true", default=True)
|
||||||
parser.add_argument("--load-embeddings", action="store_true", default=False)
|
parser.add_argument("--load-embeddings", action="store_true", default=False)
|
||||||
parser.add_argument("--use-fp16", action="store_true", default=False)
|
parser.add_argument("--use-fp16", action="store_true", default=False)
|
||||||
@@ -394,4 +430,5 @@ if __name__ == "__main__":
|
|||||||
max_batch_size=args.max_batch_size,
|
max_batch_size=args.max_batch_size,
|
||||||
lazy_load_passages=args.lazy_load_passages,
|
lazy_load_passages=args.lazy_load_passages,
|
||||||
model_name=args.model_name,
|
model_name=args.model_name,
|
||||||
|
passages_file=args.passages_file,
|
||||||
)
|
)
|
||||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: 015c201141...c7a9d681cb
@@ -468,16 +468,27 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
# --- Write CSR HNSW graph data using unified function ---
|
# --- Write CSR HNSW graph data using unified function ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...")
|
print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...")
|
||||||
|
|
||||||
# Determine storage fourcc based on prune_embeddings
|
# Determine storage fourcc and data based on prune_embeddings
|
||||||
output_storage_fourcc = NULL_INDEX_FOURCC if prune_embeddings else (storage_fourcc if 'storage_fourcc' in locals() else NULL_INDEX_FOURCC)
|
|
||||||
if prune_embeddings:
|
if prune_embeddings:
|
||||||
print(f" Pruning embeddings: Writing NULL storage marker.")
|
print(f" Pruning embeddings: Writing NULL storage marker.")
|
||||||
storage_data = b''
|
output_storage_fourcc = NULL_INDEX_FOURCC
|
||||||
|
storage_data = b''
|
||||||
|
else:
|
||||||
|
# Keep embeddings - read and preserve original storage data
|
||||||
|
if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
|
||||||
|
print(f" Preserving embeddings: Reading original storage data...")
|
||||||
|
storage_data = f_in.read() # Read remaining storage data
|
||||||
|
output_storage_fourcc = storage_fourcc
|
||||||
|
print(f" Read {len(storage_data)} bytes of storage data")
|
||||||
|
else:
|
||||||
|
print(f" No embeddings found in original file (NULL storage)")
|
||||||
|
output_storage_fourcc = NULL_INDEX_FOURCC
|
||||||
|
storage_data = b''
|
||||||
|
|
||||||
# Use the unified write function
|
# Use the unified write function
|
||||||
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
||||||
levels_np, compact_level_ptr, compact_node_offsets_np,
|
levels_np, compact_level_ptr, compact_node_offsets_np,
|
||||||
compact_neighbors_data, output_storage_fourcc, storage_data if not prune_embeddings else b'')
|
compact_neighbors_data, output_storage_fourcc, storage_data)
|
||||||
|
|
||||||
# Clean up memory
|
# Clean up memory
|
||||||
del assign_probas_np, cum_nneighbor_per_level_np, levels_np
|
del assign_probas_np, cum_nneighbor_per_level_np, levels_np
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
import json
|
import json
|
||||||
import struct
|
import struct
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any, List
|
||||||
import contextlib
|
import contextlib
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@@ -11,6 +11,7 @@ import atexit
|
|||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import pickle
|
||||||
|
|
||||||
from leann.embedding_server_manager import EmbeddingServerManager
|
from leann.embedding_server_manager import EmbeddingServerManager
|
||||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||||
@@ -74,7 +75,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
if self.is_recompute and not self.is_compact:
|
if self.is_recompute and not self.is_compact:
|
||||||
raise ValueError("is_recompute requires is_compact=True for efficiency")
|
raise ValueError("is_recompute requires is_compact=True for efficiency")
|
||||||
|
|
||||||
def build(self, data: np.ndarray, index_path: str, **kwargs):
|
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
||||||
"""Build HNSW index using FAISS"""
|
"""Build HNSW index using FAISS"""
|
||||||
from . import faiss
|
from . import faiss
|
||||||
|
|
||||||
@@ -89,6 +90,12 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
if not data.flags['C_CONTIGUOUS']:
|
if not data.flags['C_CONTIGUOUS']:
|
||||||
data = np.ascontiguousarray(data)
|
data = np.ascontiguousarray(data)
|
||||||
|
|
||||||
|
# Create label map: integer -> string_id
|
||||||
|
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||||
|
label_map_file = index_dir / "leann.labels.map"
|
||||||
|
with open(label_map_file, 'wb') as f:
|
||||||
|
pickle.dump(label_map, f)
|
||||||
|
|
||||||
metric_str = self.distance_metric.lower()
|
metric_str = self.distance_metric.lower()
|
||||||
metric_enum = get_metric_map().get(metric_str)
|
metric_enum = get_metric_map().get(metric_str)
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
@@ -119,9 +126,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
if self.is_compact:
|
if self.is_compact:
|
||||||
self._convert_to_csr(index_file)
|
self._convert_to_csr(index_file)
|
||||||
|
|
||||||
if self.is_recompute:
|
|
||||||
self._generate_passages_file(index_dir, index_prefix, **kwargs)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
|
print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
|
||||||
raise
|
raise
|
||||||
@@ -155,30 +159,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
|
print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
|
|
||||||
"""Generate passages file for recompute mode"""
|
|
||||||
try:
|
|
||||||
chunks = kwargs.get('chunks', [])
|
|
||||||
if not chunks:
|
|
||||||
print("INFO: No chunks data provided, skipping passages file generation")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Generate node_id to text mapping
|
|
||||||
passages_data = {}
|
|
||||||
for node_id, chunk in enumerate(chunks):
|
|
||||||
passages_data[str(node_id)] = chunk["text"]
|
|
||||||
|
|
||||||
# Save passages file
|
|
||||||
passages_file = index_dir / f"{index_prefix}.passages.json"
|
|
||||||
with open(passages_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(passages_data, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"💥 ERROR: Failed to generate passages file. Exception: {e}")
|
|
||||||
# Don't raise - this is not critical for index building
|
|
||||||
pass
|
|
||||||
|
|
||||||
class HNSWSearcher(LeannBackendSearcherInterface):
|
class HNSWSearcher(LeannBackendSearcherInterface):
|
||||||
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
|
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
|
||||||
@@ -282,6 +262,14 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
|||||||
self.index_dir = path.parent
|
self.index_dir = path.parent
|
||||||
self.index_prefix = path.stem
|
self.index_prefix = path.stem
|
||||||
|
|
||||||
|
# Load the label map
|
||||||
|
label_map_file = self.index_dir / "leann.labels.map"
|
||||||
|
if not label_map_file.exists():
|
||||||
|
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||||
|
|
||||||
|
with open(label_map_file, 'rb') as f:
|
||||||
|
self.label_map = pickle.load(f)
|
||||||
|
|
||||||
index_file = self.index_dir / f"{self.index_prefix}.index"
|
index_file = self.index_dir / f"{self.index_prefix}.index"
|
||||||
if not index_file.exists():
|
if not index_file.exists():
|
||||||
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
|
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
|
||||||
@@ -303,7 +291,8 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
|||||||
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
|
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
|
||||||
hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
|
hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
|
||||||
hnsw_config.external_storage_path = kwargs.get("external_storage_path")
|
hnsw_config.external_storage_path = kwargs.get("external_storage_path")
|
||||||
hnsw_config.zmq_port = kwargs.get("zmq_port", 5557)
|
|
||||||
|
self.zmq_port = kwargs.get("zmq_port", 5557)
|
||||||
|
|
||||||
if self.is_pruned and not hnsw_config.is_recompute:
|
if self.is_pruned and not hnsw_config.is_recompute:
|
||||||
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
|
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
|
||||||
@@ -335,12 +324,13 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
|||||||
|
|
||||||
passages_file = kwargs.get("passages_file")
|
passages_file = kwargs.get("passages_file")
|
||||||
if not passages_file:
|
if not passages_file:
|
||||||
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
|
# Get the passages file path from meta.json
|
||||||
if potential_passages_file.exists():
|
if 'passage_sources' in self.meta and self.meta['passage_sources']:
|
||||||
passages_file = str(potential_passages_file)
|
passage_source = self.meta['passage_sources'][0]
|
||||||
print(f"INFO: Automatically found passages file: {passages_file}")
|
passages_file = passage_source['path']
|
||||||
|
print(f"INFO: Found passages file from metadata: {passages_file}")
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"FATAL: Index is pruned but no passages file found.")
|
raise RuntimeError(f"FATAL: Index is pruned but no passage_sources found in metadata.")
|
||||||
|
|
||||||
zmq_port = kwargs.get("zmq_port", 5557)
|
zmq_port = kwargs.get("zmq_port", 5557)
|
||||||
server_started = self.embedding_server_manager.start_server(
|
server_started = self.embedding_server_manager.start_server(
|
||||||
@@ -361,15 +351,28 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
|||||||
faiss.normalize_L2(query)
|
faiss.normalize_L2(query)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._index.hnsw.efSearch = ef
|
params = faiss.SearchParametersHNSW()
|
||||||
|
params.efSearch = ef
|
||||||
|
params.zmq_port = kwargs.get("zmq_port", self.zmq_port)
|
||||||
|
|
||||||
batch_size = query.shape[0]
|
batch_size = query.shape[0]
|
||||||
distances = np.empty((batch_size, top_k), dtype=np.float32)
|
distances = np.empty((batch_size, top_k), dtype=np.float32)
|
||||||
labels = np.empty((batch_size, top_k), dtype=np.int64)
|
labels = np.empty((batch_size, top_k), dtype=np.int64)
|
||||||
|
|
||||||
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
|
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
|
||||||
|
|
||||||
return {"labels": labels, "distances": distances}
|
# Convert integer labels to string IDs
|
||||||
|
string_labels = []
|
||||||
|
for batch_labels in labels:
|
||||||
|
batch_string_labels = []
|
||||||
|
for int_label in batch_labels:
|
||||||
|
if int_label in self.label_map:
|
||||||
|
batch_string_labels.append(self.label_map[int_label])
|
||||||
|
else:
|
||||||
|
batch_string_labels.append(f"unknown_{int_label}")
|
||||||
|
string_labels.append(batch_string_labels)
|
||||||
|
|
||||||
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"💥 ERROR: HNSW search failed. Exception: {e}")
|
print(f"💥 ERROR: HNSW search failed. Exception: {e}")
|
||||||
|
|||||||
@@ -58,21 +58,46 @@ class SimplePassageLoader:
|
|||||||
|
|
||||||
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
|
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
|
||||||
"""
|
"""
|
||||||
Load passages from a JSON file
|
Load passages from a JSONL file with label map support
|
||||||
Expected format: {"passage_id": "passage_text", ...}
|
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
|
||||||
"""
|
"""
|
||||||
if not os.path.exists(passages_file):
|
if not os.path.exists(passages_file):
|
||||||
print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
|
raise FileNotFoundError(f"Passages file {passages_file} not found.")
|
||||||
return SimplePassageLoader()
|
|
||||||
|
|
||||||
try:
|
if not passages_file.endswith('.jsonl'):
|
||||||
with open(passages_file, 'r', encoding='utf-8') as f:
|
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
|
||||||
passages_data = json.load(f)
|
|
||||||
print(f"Loaded {len(passages_data)} passages from {passages_file}")
|
# Load label map (int -> string_id)
|
||||||
return SimplePassageLoader(passages_data)
|
passages_dir = Path(passages_file).parent
|
||||||
except Exception as e:
|
label_map_file = passages_dir / "leann.labels.map"
|
||||||
print(f"Error loading passages from {passages_file}: {e}")
|
|
||||||
return SimplePassageLoader()
|
label_map = {}
|
||||||
|
if label_map_file.exists():
|
||||||
|
import pickle
|
||||||
|
with open(label_map_file, 'rb') as f:
|
||||||
|
label_map = pickle.load(f)
|
||||||
|
print(f"Loaded label map with {len(label_map)} entries")
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||||
|
|
||||||
|
# Load passages by string ID
|
||||||
|
string_id_passages = {}
|
||||||
|
with open(passages_file, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
passage = json.loads(line)
|
||||||
|
string_id_passages[passage['id']] = passage['text']
|
||||||
|
|
||||||
|
# Create int ID -> text mapping using label map
|
||||||
|
passages_data = {}
|
||||||
|
for int_id, string_id in label_map.items():
|
||||||
|
if string_id in string_id_passages:
|
||||||
|
passages_data[str(int_id)] = string_id_passages[string_id]
|
||||||
|
else:
|
||||||
|
print(f"WARNING: String ID {string_id} from label map not found in passages")
|
||||||
|
|
||||||
|
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
|
||||||
|
return SimplePassageLoader(passages_data)
|
||||||
|
|
||||||
def create_hnsw_embedding_server(
|
def create_hnsw_embedding_server(
|
||||||
passages_file: Optional[str] = None,
|
passages_file: Optional[str] = None,
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
# packages/leann-core/src/leann/__init__.py
|
||||||
|
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
from .registry import BACKEND_REGISTRY, autodiscover_backends
|
||||||
|
|
||||||
|
autodiscover_backends()
|
||||||
|
|
||||||
|
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"]
|
||||||
@@ -7,6 +7,8 @@ import json
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import openai
|
import openai
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
import uuid
|
||||||
|
import pickle
|
||||||
|
|
||||||
# --- Helper Functions for Embeddings ---
|
# --- Helper Functions for Embeddings ---
|
||||||
|
|
||||||
@@ -56,11 +58,45 @@ def _get_embedding_dimensions(model_name: str) -> int:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class SearchResult:
|
class SearchResult:
|
||||||
"""Represents a single search result."""
|
"""Represents a single search result."""
|
||||||
id: int
|
id: str
|
||||||
score: float
|
score: float
|
||||||
text: str
|
text: str
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class PassageManager:
|
||||||
|
"""Manages passage data and lazy loading from JSONL files."""
|
||||||
|
|
||||||
|
def __init__(self, passage_sources: List[Dict[str, Any]]):
|
||||||
|
self.offset_maps = {}
|
||||||
|
self.passage_files = {}
|
||||||
|
|
||||||
|
for source in passage_sources:
|
||||||
|
if source["type"] == "jsonl":
|
||||||
|
passage_file = source["path"]
|
||||||
|
index_file = source["index_path"]
|
||||||
|
|
||||||
|
if not os.path.exists(index_file):
|
||||||
|
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||||
|
|
||||||
|
with open(index_file, 'rb') as f:
|
||||||
|
offset_map = pickle.load(f)
|
||||||
|
|
||||||
|
self.offset_maps[passage_file] = offset_map
|
||||||
|
self.passage_files[passage_file] = passage_file
|
||||||
|
|
||||||
|
def get_passage(self, passage_id: str) -> Dict[str, Any]:
|
||||||
|
"""Lazy load a passage by ID."""
|
||||||
|
for passage_file, offset_map in self.offset_maps.items():
|
||||||
|
if passage_id in offset_map:
|
||||||
|
offset = offset_map[passage_id]
|
||||||
|
with open(passage_file, 'r', encoding='utf-8') as f:
|
||||||
|
f.seek(offset)
|
||||||
|
line = f.readline()
|
||||||
|
return json.loads(line)
|
||||||
|
|
||||||
|
raise KeyError(f"Passage ID not found: {passage_id}")
|
||||||
|
|
||||||
# --- Core Classes ---
|
# --- Core Classes ---
|
||||||
|
|
||||||
class LeannBuilder:
|
class LeannBuilder:
|
||||||
@@ -82,7 +118,26 @@ class LeannBuilder:
|
|||||||
print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")
|
print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")
|
||||||
|
|
||||||
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
|
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
|
||||||
self.chunks.append({"text": text, "metadata": metadata or {}})
|
if metadata is None:
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
# Check if ID is provided in metadata
|
||||||
|
passage_id = metadata.get('id')
|
||||||
|
if passage_id is None:
|
||||||
|
passage_id = str(uuid.uuid4())
|
||||||
|
else:
|
||||||
|
# Validate uniqueness
|
||||||
|
existing_ids = {chunk['id'] for chunk in self.chunks}
|
||||||
|
if passage_id in existing_ids:
|
||||||
|
raise ValueError(f"Duplicate passage ID: {passage_id}")
|
||||||
|
|
||||||
|
# Store the definitive ID with the chunk
|
||||||
|
chunk_data = {
|
||||||
|
"id": passage_id,
|
||||||
|
"text": text,
|
||||||
|
"metadata": metadata
|
||||||
|
}
|
||||||
|
self.chunks.append(chunk_data)
|
||||||
|
|
||||||
def build_index(self, index_path: str):
|
def build_index(self, index_path: str):
|
||||||
if not self.chunks:
|
if not self.chunks:
|
||||||
@@ -92,28 +147,65 @@ class LeannBuilder:
|
|||||||
self.dimensions = _get_embedding_dimensions(self.embedding_model)
|
self.dimensions = _get_embedding_dimensions(self.embedding_model)
|
||||||
print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")
|
print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")
|
||||||
|
|
||||||
|
path = Path(index_path)
|
||||||
|
index_dir = path.parent
|
||||||
|
index_name = path.name
|
||||||
|
|
||||||
|
# Ensure the directory exists
|
||||||
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Create the passages.jsonl file and offset index
|
||||||
|
passages_file = index_dir / f"{index_name}.passages.jsonl"
|
||||||
|
offset_file = index_dir / f"{index_name}.passages.idx"
|
||||||
|
|
||||||
|
offset_map = {}
|
||||||
|
|
||||||
|
with open(passages_file, 'w', encoding='utf-8') as f:
|
||||||
|
for chunk in self.chunks:
|
||||||
|
offset = f.tell()
|
||||||
|
passage_data = {
|
||||||
|
"id": chunk["id"],
|
||||||
|
"text": chunk["text"],
|
||||||
|
"metadata": chunk["metadata"]
|
||||||
|
}
|
||||||
|
json.dump(passage_data, f, ensure_ascii=False)
|
||||||
|
f.write('\n')
|
||||||
|
offset_map[chunk["id"]] = offset
|
||||||
|
|
||||||
|
# Save the offset map
|
||||||
|
with open(offset_file, 'wb') as f:
|
||||||
|
pickle.dump(offset_map, f)
|
||||||
|
|
||||||
|
# Compute embeddings
|
||||||
texts_to_embed = [c["text"] for c in self.chunks]
|
texts_to_embed = [c["text"] for c in self.chunks]
|
||||||
embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)
|
embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)
|
||||||
|
|
||||||
|
# Extract string IDs for the backend
|
||||||
|
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||||
|
|
||||||
|
# Build the vector index
|
||||||
current_backend_kwargs = self.backend_kwargs.copy()
|
current_backend_kwargs = self.backend_kwargs.copy()
|
||||||
current_backend_kwargs['dimensions'] = self.dimensions
|
current_backend_kwargs['dimensions'] = self.dimensions
|
||||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||||
|
|
||||||
build_kwargs = current_backend_kwargs.copy()
|
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
|
||||||
build_kwargs['chunks'] = self.chunks
|
|
||||||
builder_instance.build(embeddings, index_path, **build_kwargs)
|
|
||||||
|
|
||||||
index_dir = Path(index_path).parent
|
# Create the lightweight meta.json file
|
||||||
leann_meta_path = index_dir / f"{Path(index_path).name}.meta.json"
|
leann_meta_path = index_dir / f"{index_name}.meta.json"
|
||||||
|
|
||||||
meta_data = {
|
meta_data = {
|
||||||
"version": "0.1.0",
|
"version": "1.0",
|
||||||
"backend_name": self.backend_name,
|
"backend_name": self.backend_name,
|
||||||
"embedding_model": self.embedding_model,
|
"embedding_model": self.embedding_model,
|
||||||
"dimensions": self.dimensions,
|
"dimensions": self.dimensions,
|
||||||
"backend_kwargs": self.backend_kwargs,
|
"backend_kwargs": self.backend_kwargs,
|
||||||
"num_chunks": len(self.chunks),
|
"passage_sources": [
|
||||||
"chunks": self.chunks,
|
{
|
||||||
|
"type": "jsonl",
|
||||||
|
"path": str(passages_file),
|
||||||
|
"index_path": str(offset_file)
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
with open(leann_meta_path, 'w', encoding='utf-8') as f:
|
with open(leann_meta_path, 'w', encoding='utf-8') as f:
|
||||||
json.dump(meta_data, f, indent=2)
|
json.dump(meta_data, f, indent=2)
|
||||||
@@ -136,14 +228,16 @@ class LeannSearcher:
|
|||||||
backend_name = self.meta_data['backend_name']
|
backend_name = self.meta_data['backend_name']
|
||||||
self.embedding_model = self.meta_data['embedding_model']
|
self.embedding_model = self.meta_data['embedding_model']
|
||||||
|
|
||||||
|
# Initialize the passage manager
|
||||||
|
passage_sources = self.meta_data.get('passage_sources', [])
|
||||||
|
self.passage_manager = PassageManager(passage_sources)
|
||||||
|
|
||||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.")
|
raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.")
|
||||||
|
|
||||||
final_kwargs = self.meta_data.get("backend_kwargs", {})
|
final_kwargs = backend_kwargs.copy()
|
||||||
final_kwargs.update(backend_kwargs)
|
final_kwargs['meta'] = self.meta_data
|
||||||
if 'dimensions' not in final_kwargs:
|
|
||||||
final_kwargs['dimensions'] = self.meta_data.get('dimensions')
|
|
||||||
|
|
||||||
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
|
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
|
||||||
print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")
|
print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")
|
||||||
@@ -155,15 +249,17 @@ class LeannSearcher:
|
|||||||
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
|
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
|
||||||
|
|
||||||
enriched_results = []
|
enriched_results = []
|
||||||
for label, dist in zip(results['labels'][0], results['distances'][0]):
|
for string_id, dist in zip(results['labels'][0], results['distances'][0]):
|
||||||
if label < len(self.meta_data['chunks']):
|
try:
|
||||||
chunk_info = self.meta_data['chunks'][label]
|
passage_data = self.passage_manager.get_passage(string_id)
|
||||||
enriched_results.append(SearchResult(
|
enriched_results.append(SearchResult(
|
||||||
id=label,
|
id=string_id,
|
||||||
score=dist,
|
score=dist,
|
||||||
text=chunk_info['text'],
|
text=passage_data['text'],
|
||||||
metadata=chunk_info.get('metadata', {})
|
metadata=passage_data.get('metadata', {})
|
||||||
))
|
))
|
||||||
|
except KeyError:
|
||||||
|
print(f"WARNING: Passage ID '{string_id}' not found in passage files")
|
||||||
return enriched_results
|
return enriched_results
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
# packages/leann-core/src/leann/registry.py
|
# packages/leann-core/src/leann/registry.py
|
||||||
|
|
||||||
from typing import Dict, TYPE_CHECKING
|
from typing import Dict, TYPE_CHECKING
|
||||||
|
import importlib
|
||||||
|
import importlib.metadata
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from leann.interface import LeannBackendFactoryInterface
|
from leann.interface import LeannBackendFactoryInterface
|
||||||
|
|
||||||
@@ -12,4 +15,22 @@ def register_backend(name: str):
|
|||||||
print(f"INFO: Registering backend '{name}'")
|
print(f"INFO: Registering backend '{name}'")
|
||||||
BACKEND_REGISTRY[name] = cls
|
BACKEND_REGISTRY[name] = cls
|
||||||
return cls
|
return cls
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def autodiscover_backends():
|
||||||
|
"""Automatically discovers and imports all 'leann-backend-*' packages."""
|
||||||
|
print("INFO: Starting backend auto-discovery...")
|
||||||
|
discovered_backends = []
|
||||||
|
for dist in importlib.metadata.distributions():
|
||||||
|
dist_name = dist.metadata['name']
|
||||||
|
if dist_name.startswith('leann-backend-'):
|
||||||
|
backend_module_name = dist_name.replace('-', '_')
|
||||||
|
discovered_backends.append(backend_module_name)
|
||||||
|
|
||||||
|
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
|
||||||
|
try:
|
||||||
|
importlib.import_module(backend_module_name)
|
||||||
|
# Registration message is printed by the decorator
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
||||||
|
print("INFO: Backend auto-discovery finished.")
|
||||||
File diff suppressed because it is too large
Load Diff
68
tests/sanity_checks/README_hnsw_pruning.md
Normal file
68
tests/sanity_checks/README_hnsw_pruning.md
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# HNSW Index Storage Optimization
|
||||||
|
|
||||||
|
This document explains the storage optimization features available in the HNSW backend.
|
||||||
|
|
||||||
|
## Storage Modes
|
||||||
|
|
||||||
|
The HNSW backend supports two orthogonal optimization techniques:
|
||||||
|
|
||||||
|
### 1. CSR Compression (`is_compact=True`)
|
||||||
|
- Converts the graph structure from standard format to Compressed Sparse Row (CSR) format
|
||||||
|
- Reduces memory overhead from graph adjacency storage
|
||||||
|
- Maintains all embedding data for direct access
|
||||||
|
|
||||||
|
### 2. Embedding Pruning (`is_recompute=True`)
|
||||||
|
- Removes embedding vectors from the index file
|
||||||
|
- Replaces them with a NULL storage marker
|
||||||
|
- Requires recomputation via embedding server during search
|
||||||
|
- Must be used with `is_compact=True` for efficiency
|
||||||
|
|
||||||
|
## Performance Impact
|
||||||
|
|
||||||
|
**Storage Reduction (100 vectors, 384 dimensions):**
|
||||||
|
```
|
||||||
|
Standard format: 168 KB (embeddings + graph)
|
||||||
|
CSR only: 160 KB (embeddings + compressed graph)
|
||||||
|
CSR + Pruned: 6 KB (compressed graph only)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Benefits:**
|
||||||
|
- **CSR compression**: ~5% size reduction from graph optimization
|
||||||
|
- **Embedding pruning**: ~95% size reduction by removing embeddings
|
||||||
|
- **Combined**: Up to 96% total storage reduction
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Standard format (largest)
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# CSR compressed (medium)
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# CSR + Pruned (smallest, requires embedding server)
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
is_compact=True, # Required for pruning
|
||||||
|
is_recompute=True # Default: enabled
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Trade-offs
|
||||||
|
|
||||||
|
| Mode | Storage | Search Speed | Memory Usage | Setup Complexity |
|
||||||
|
|------|---------|--------------|--------------|------------------|
|
||||||
|
| Standard | Largest | Fastest | Highest | Simple |
|
||||||
|
| CSR | Medium | Fast | Medium | Simple |
|
||||||
|
| CSR + Pruned | Smallest | Slower* | Lowest | Complex** |
|
||||||
|
|
||||||
|
*Requires network round-trip to embedding server for recomputation
|
||||||
|
**Needs embedding server and passages file for search
|
||||||
156
tests/sanity_checks/test_hnsw_pruning.py
Normal file
156
tests/sanity_checks/test_hnsw_pruning.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Sanity check script to verify HNSW index pruning effectiveness.
|
||||||
|
Tests the difference in file sizes between pruned and non-pruned indices.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
|
||||||
|
# Add the project root to the Python path
|
||||||
|
project_root = Path(__file__).parent.parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
# Import backend packages to trigger plugin registration
|
||||||
|
import leann_backend_hnsw
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
def create_sample_documents(num_docs=1000):
|
||||||
|
"""Create sample documents for testing"""
|
||||||
|
documents = []
|
||||||
|
for i in range(num_docs):
|
||||||
|
documents.append(f"Sample document {i} with some random text content for testing purposes.")
|
||||||
|
return documents
|
||||||
|
|
||||||
|
def build_index(documents, output_dir, is_recompute=True):
|
||||||
|
"""Build HNSW index with specified recompute setting"""
|
||||||
|
index_path = os.path.join(output_dir, "test_index.hnsw")
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
M=16,
|
||||||
|
efConstruction=100,
|
||||||
|
distance_metric="mips",
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=is_recompute
|
||||||
|
)
|
||||||
|
|
||||||
|
for doc in documents:
|
||||||
|
builder.add_text(doc)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
def get_file_size(filepath):
|
||||||
|
"""Get file size in bytes"""
|
||||||
|
return os.path.getsize(filepath)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("🔍 HNSW Pruning Sanity Check")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Create sample data
|
||||||
|
print("📊 Creating sample documents...")
|
||||||
|
documents = create_sample_documents(num_docs=1000)
|
||||||
|
print(f" Number of documents: {len(documents)}")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
print(f"📁 Working in temporary directory: {temp_dir}")
|
||||||
|
|
||||||
|
# Build index with pruning (is_recompute=True)
|
||||||
|
print("\n🔨 Building index with pruning enabled (is_recompute=True)...")
|
||||||
|
pruned_dir = os.path.join(temp_dir, "pruned")
|
||||||
|
os.makedirs(pruned_dir, exist_ok=True)
|
||||||
|
|
||||||
|
pruned_index_path = build_index(documents, pruned_dir, is_recompute=True)
|
||||||
|
# Check what files were actually created
|
||||||
|
print(f" Looking for index files at: {pruned_index_path}")
|
||||||
|
import glob
|
||||||
|
files = glob.glob(f"{pruned_index_path}*")
|
||||||
|
print(f" Found files: {files}")
|
||||||
|
|
||||||
|
# Try to find the actual index file
|
||||||
|
if os.path.exists(f"{pruned_index_path}.index"):
|
||||||
|
pruned_index_file = f"{pruned_index_path}.index"
|
||||||
|
else:
|
||||||
|
# Look for any .index file in the directory
|
||||||
|
index_files = glob.glob(f"{pruned_dir}/*.index")
|
||||||
|
if index_files:
|
||||||
|
pruned_index_file = index_files[0]
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"No .index file found in {pruned_dir}")
|
||||||
|
|
||||||
|
pruned_size = get_file_size(pruned_index_file)
|
||||||
|
print(f" ✅ Pruned index built successfully")
|
||||||
|
print(f" 📏 Pruned index size: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)")
|
||||||
|
|
||||||
|
# Build index without pruning (is_recompute=False)
|
||||||
|
print("\n🔨 Building index without pruning (is_recompute=False)...")
|
||||||
|
non_pruned_dir = os.path.join(temp_dir, "non_pruned")
|
||||||
|
os.makedirs(non_pruned_dir, exist_ok=True)
|
||||||
|
|
||||||
|
non_pruned_index_path = build_index(documents, non_pruned_dir, is_recompute=False)
|
||||||
|
# Check what files were actually created
|
||||||
|
print(f" Looking for index files at: {non_pruned_index_path}")
|
||||||
|
files = glob.glob(f"{non_pruned_index_path}*")
|
||||||
|
print(f" Found files: {files}")
|
||||||
|
|
||||||
|
# Try to find the actual index file
|
||||||
|
if os.path.exists(f"{non_pruned_index_path}.index"):
|
||||||
|
non_pruned_index_file = f"{non_pruned_index_path}.index"
|
||||||
|
else:
|
||||||
|
# Look for any .index file in the directory
|
||||||
|
index_files = glob.glob(f"{non_pruned_dir}/*.index")
|
||||||
|
if index_files:
|
||||||
|
non_pruned_index_file = index_files[0]
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"No .index file found in {non_pruned_dir}")
|
||||||
|
|
||||||
|
non_pruned_size = get_file_size(non_pruned_index_file)
|
||||||
|
print(f" ✅ Non-pruned index built successfully")
|
||||||
|
print(f" 📏 Non-pruned index size: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)")
|
||||||
|
|
||||||
|
# Compare sizes
|
||||||
|
print("\n📊 Comparison Results:")
|
||||||
|
print("=" * 30)
|
||||||
|
size_diff = non_pruned_size - pruned_size
|
||||||
|
size_ratio = pruned_size / non_pruned_size if non_pruned_size > 0 else 0
|
||||||
|
reduction_percent = (1 - size_ratio) * 100
|
||||||
|
|
||||||
|
print(f"Non-pruned index: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)")
|
||||||
|
print(f"Pruned index: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)")
|
||||||
|
print(f"Size difference: {size_diff:,} bytes ({size_diff/1024:.1f} KB)")
|
||||||
|
print(f"Size ratio: {size_ratio:.3f}")
|
||||||
|
print(f"Size reduction: {reduction_percent:.1f}%")
|
||||||
|
|
||||||
|
# Verify pruning effectiveness
|
||||||
|
print("\n🔍 Verification:")
|
||||||
|
if size_diff > 0:
|
||||||
|
print(" ✅ Pruning is effective - pruned index is smaller")
|
||||||
|
if reduction_percent > 10:
|
||||||
|
print(f" ✅ Significant size reduction: {reduction_percent:.1f}%")
|
||||||
|
else:
|
||||||
|
print(f" ⚠️ Small size reduction: {reduction_percent:.1f}%")
|
||||||
|
else:
|
||||||
|
print(" ❌ Pruning appears ineffective - no size reduction")
|
||||||
|
|
||||||
|
# Check if passages files were created
|
||||||
|
pruned_passages = f"{pruned_index_path}.passages.json"
|
||||||
|
non_pruned_passages = f"{non_pruned_index_path}.passages.json"
|
||||||
|
|
||||||
|
print(f"\n📄 Passages files:")
|
||||||
|
print(f" Pruned passages file exists: {os.path.exists(pruned_passages)}")
|
||||||
|
print(f" Non-pruned passages file exists: {os.path.exists(non_pruned_passages)}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = main()
|
||||||
|
sys.exit(0 if success else 1)
|
||||||
Reference in New Issue
Block a user