Compare commits

...

4 Commits

Author      SHA1          Message                                                             Date
Andy Lee    16705fc44a    refactor: passage structure                                         2025-07-06 21:48:38 +00:00
Andy Lee    5611f708e9    docs: embedding pruning                                             2025-07-06 19:50:01 +00:00
Andy Lee    b4ae57b2c0    feat: auto discovery of packages and fix passage gen for diskann   2025-07-06 05:05:49 +00:00
Andy Lee    5659174635    fix: diskann zmq port and passages                                  2025-07-06 04:14:15 +00:00
15 changed files with 629 additions and 9748 deletions

.gitignore

@@ -71,3 +71,6 @@ test_indices*/
test_*.py test_*.py
!tests/** !tests/**
packages/leann-backend-diskann/third_party/DiskANN/_deps/ packages/leann-backend-diskann/third_party/DiskANN/_deps/
*.meta.json
*.passages.json


@@ -74,7 +74,7 @@ def main():
print(f"⏱️ Basic search time: {basic_time:.3f} seconds") print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
print(">>> Basic search results <<<") print(">>> Basic search results <<<")
for i, res in enumerate(results, 1): for i, res in enumerate(results, 1):
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}") print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
# --- 3. Recompute search demo --- # --- 3. Recompute search demo ---
print(f"\n[PHASE 3] Recompute search using embedding server...") print(f"\n[PHASE 3] Recompute search using embedding server...")
@@ -107,7 +107,7 @@ def main():
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds") print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
print(">>> Recompute search results <<<") print(">>> Recompute search results <<<")
for i, res in enumerate(recompute_results, 1): for i, res in enumerate(recompute_results, 1):
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}") print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
# Compare results # Compare results
print(f"\n--- Result comparison ---") print(f"\n--- Result comparison ---")
@@ -116,8 +116,8 @@ def main():
print("\nBasic search vs Recompute results:") print("\nBasic search vs Recompute results:")
for i in range(min(len(results), len(recompute_results))): for i in range(min(len(results), len(recompute_results))):
basic_score = results[i]['score'] basic_score = results[i].score
recompute_score = recompute_results[i]['score'] recompute_score = recompute_results[i].score
score_diff = abs(basic_score - recompute_score) score_diff = abs(basic_score - recompute_score)
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}") print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")


@@ -10,7 +10,6 @@ import asyncio
import os import os
import dotenv import dotenv
from leann.api import LeannBuilder, LeannSearcher, LeannChat from leann.api import LeannBuilder, LeannSearcher, LeannChat
import leann_backend_hnsw # Import to ensure backend registration
import shutil import shutil
from pathlib import Path from pathlib import Path
@@ -39,7 +38,7 @@ all_texts = []
for doc in documents: for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc]) nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes: for node in nodes:
all_texts.append(node.text) all_texts.append(node.get_content())
INDEX_DIR = Path("./test_pdf_index") INDEX_DIR = Path("./test_pdf_index")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann") INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
@@ -51,7 +50,7 @@ if not INDEX_DIR.exists():
# CSR compact mode with recompute # CSR compact mode with recompute
builder = LeannBuilder( builder = LeannBuilder(
backend_name="hnsw", backend_name="diskann",
embedding_model="facebook/contriever", embedding_model="facebook/contriever",
graph_degree=32, graph_degree=32,
complexity=64, complexity=64,
@@ -74,7 +73,7 @@ async def main():
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?" query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
print(f"You: {query}") print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,embedding_model="facebook/contriever") chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True)
print(f"Leann: {chat_response}") print(f"Leann: {chat_response}")
if __name__ == "__main__": if __name__ == "__main__":
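
A condensed sketch of the revised build path in this example, assuming the script's document loader and node parser come from llama_index; only the changed calls are certain from the hunk above, while the reader/splitter classes, chunk size, and paths are assumptions.

```python
from pathlib import Path
from llama_index.core import SimpleDirectoryReader            # assumed loader
from llama_index.core.node_parser import SentenceSplitter     # assumed node parser
from leann.api import LeannBuilder

documents = SimpleDirectoryReader("./data").load_data()
node_parser = SentenceSplitter(chunk_size=512)

all_texts = []
for doc in documents:
    for node in node_parser.get_nodes_from_documents([doc]):
        all_texts.append(node.get_content())                   # was node.text before this change

INDEX_DIR = Path("./test_pdf_index")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
if not INDEX_DIR.exists():
    builder = LeannBuilder(
        backend_name="diskann",                                 # switched from "hnsw"
        embedding_model="facebook/contriever",
        graph_degree=32,
        complexity=64,
    )
    for text in all_texts:
        builder.add_text(text)
    builder.build_index(INDEX_PATH)
```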


@@ -3,7 +3,7 @@ import os
import json import json
import struct import struct
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict, Any, List
import contextlib import contextlib
import threading import threading
import time import time
@@ -11,6 +11,7 @@ import atexit
import socket import socket
import subprocess import subprocess
import sys import sys
import pickle
from leann.embedding_server_manager import EmbeddingServerManager from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_backend from leann.registry import register_backend
@@ -19,13 +20,13 @@ from leann.interface import (
LeannBackendBuilderInterface, LeannBackendBuilderInterface,
LeannBackendSearcherInterface LeannBackendSearcherInterface
) )
from . import _diskannpy as diskannpy def _get_diskann_metrics():
from . import _diskannpy as diskannpy
METRIC_MAP = { return {
"mips": diskannpy.Metric.INNER_PRODUCT, "mips": diskannpy.Metric.INNER_PRODUCT,
"l2": diskannpy.Metric.L2, "l2": diskannpy.Metric.L2,
"cosine": diskannpy.Metric.COSINE, "cosine": diskannpy.Metric.COSINE,
} }
@contextlib.contextmanager @contextlib.contextmanager
def chdir(path): def chdir(path):
@@ -36,7 +37,7 @@ def chdir(path):
finally: finally:
os.chdir(original_dir) os.chdir(original_dir)
def _write_vectors_to_bin(data: np.ndarray, file_path: str): def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
num_vectors, dim = data.shape num_vectors, dim = data.shape
with open(file_path, 'wb') as f: with open(file_path, 'wb') as f:
f.write(struct.pack('I', num_vectors)) f.write(struct.pack('I', num_vectors))
@@ -67,7 +68,8 @@ class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.build_params = kwargs self.build_params = kwargs
def build(self, data: np.ndarray, index_path: str, **kwargs):
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
path = Path(index_path) path = Path(index_path)
index_dir = path.parent index_dir = path.parent
index_prefix = path.stem index_prefix = path.stem
@@ -82,8 +84,15 @@ class DiskannBuilder(LeannBackendBuilderInterface):
data_filename = f"{index_prefix}_data.bin" data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename) _write_vectors_to_bin(data, index_dir / data_filename)
# Create label map: integer -> string_id
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f:
pickle.dump(label_map, f)
build_kwargs = {**self.build_params, **kwargs} build_kwargs = {**self.build_params, **kwargs}
metric_str = build_kwargs.get("distance_metric", "mips").lower() metric_str = build_kwargs.get("distance_metric", "mips").lower()
METRIC_MAP = _get_diskann_metrics()
metric_enum = METRIC_MAP.get(metric_str) metric_enum = METRIC_MAP.get(metric_str)
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.") raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
@@ -99,6 +108,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...") print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
try: try:
from . import _diskannpy as diskannpy
with chdir(index_dir): with chdir(index_dir):
diskannpy.build_disk_float_index( diskannpy.build_disk_float_index(
metric_enum, metric_enum,
@@ -127,30 +137,38 @@ class DiskannSearcher(LeannBackendSearcherInterface):
if not self.meta: if not self.meta:
raise ValueError("DiskannSearcher requires metadata from .meta.json.") raise ValueError("DiskannSearcher requires metadata from .meta.json.")
dimensions = self.meta.get("dimensions")
if not dimensions:
raise ValueError("Dimensions not found in Leann metadata.")
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
metric_enum = METRIC_MAP.get(self.distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
self.embedding_model = self.meta.get("embedding_model") self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model: if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.") print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
path = Path(index_path) path = Path(index_path)
index_dir = path.parent self.index_dir = path.parent
index_prefix = path.stem self.index_prefix = path.stem
# Load the label map
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
self.label_map = pickle.load(f)
# Extract parameters for DiskANN
distance_metric = kwargs.get("distance_metric", "mips").lower()
METRIC_MAP = _get_diskann_metrics()
metric_enum = METRIC_MAP.get(distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
num_threads = kwargs.get("num_threads", 8) num_threads = kwargs.get("num_threads", 8)
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0) num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
self.zmq_port = kwargs.get("zmq_port", 6666)
try: try:
full_index_prefix = str(index_dir / index_prefix) from . import _diskannpy as diskannpy
full_index_prefix = str(self.index_dir / self.index_prefix)
self._index = diskannpy.StaticDiskFloatIndex( self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", "" metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
) )
self.num_threads = num_threads self.num_threads = num_threads
self.embedding_server_manager = EmbeddingServerManager( self.embedding_server_manager = EmbeddingServerManager(
@@ -161,7 +179,7 @@ class DiskannSearcher(LeannBackendSearcherInterface):
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}") print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
raise raise
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]: def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
complexity = kwargs.get("complexity", 256) complexity = kwargs.get("complexity", 256)
beam_width = kwargs.get("beam_width", 4) beam_width = kwargs.get("beam_width", 4)
@@ -172,23 +190,32 @@ class DiskannSearcher(LeannBackendSearcherInterface):
prune_ratio = kwargs.get("prune_ratio", 0.0) prune_ratio = kwargs.get("prune_ratio", 0.0)
batch_recompute = kwargs.get("batch_recompute", False) batch_recompute = kwargs.get("batch_recompute", False)
global_pruning = kwargs.get("global_pruning", False) global_pruning = kwargs.get("global_pruning", False)
port = kwargs.get("zmq_port", self.zmq_port)
if recompute_beighbor_embeddings: if recompute_beighbor_embeddings:
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running") print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
if not self.embedding_model: if not self.embedding_model:
raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.") raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
zmq_port = kwargs.get("zmq_port", 6666) passages_file = kwargs.get("passages_file")
if not passages_file:
# Get the passages file path from meta.json
if 'passage_sources' in self.meta and self.meta['passage_sources']:
passage_source = self.meta['passage_sources'][0]
passages_file = passage_source['path']
print(f"INFO: Found passages file from metadata: {passages_file}")
else:
raise RuntimeError(f"FATAL: Recompute mode enabled but no passage_sources found in metadata.")
server_started = self.embedding_server_manager.start_server( server_started = self.embedding_server_manager.start_server(
port=zmq_port, port=self.zmq_port,
model_name=self.embedding_model, model_name=self.embedding_model,
distance_metric=self.distance_metric distance_metric=kwargs.get("distance_metric", "mips"),
passages_file=passages_file
) )
if not server_started: if not server_started:
print(f"WARNING: Failed to start embedding server, falling back to PQ computation") raise RuntimeError(f"Failed to start DiskANN embedding server on port {self.zmq_port}")
recompute_beighbor_embeddings = False
if query.dtype != np.float32: if query.dtype != np.float32:
query = query.astype(np.float32) query = query.astype(np.float32)
@@ -211,11 +238,23 @@ class DiskannSearcher(LeannBackendSearcherInterface):
batch_recompute, batch_recompute,
global_pruning global_pruning
) )
return {"labels": labels, "distances": distances}
# Convert integer labels to string IDs
string_labels = []
for batch_labels in labels:
batch_string_labels = []
for int_label in batch_labels:
if int_label in self.label_map:
batch_string_labels.append(self.label_map[int_label])
else:
batch_string_labels.append(f"unknown_{int_label}")
string_labels.append(batch_string_labels)
return {"labels": string_labels, "distances": distances}
except Exception as e: except Exception as e:
print(f"💥 ERROR: DiskANN search failed. Exception: {e}") print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
batch_size = query.shape[0] batch_size = query.shape[0]
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64), return {"labels": [[f"error_{i}" for i in range(top_k)] for _ in range(batch_size)],
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)} "distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
def __del__(self): def __del__(self):
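
To summarize the ID handling this diff introduces, here is an illustrative restatement (not repository code) of the label-map round trip: the builder pickles an integer-to-string mapping next to the index, and the searcher uses it to translate DiskANN's integer labels back into passage IDs.

```python
import pickle
from pathlib import Path
from typing import Dict, List

def write_label_map(ids: List[str], index_dir: Path) -> None:
    # Builder side: DiskANN only knows integer labels 0..N-1.
    label_map = {i: str_id for i, str_id in enumerate(ids)}
    with open(index_dir / "leann.labels.map", "wb") as f:
        pickle.dump(label_map, f)

def to_string_labels(int_labels, label_map: Dict[int, str]) -> List[List[str]]:
    # Searcher side: mirrors the conversion loop added to DiskannSearcher.search().
    return [[label_map.get(i, f"unknown_{i}") for i in batch] for batch in int_labels]
```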


@@ -7,6 +7,8 @@ import pickle
import argparse import argparse
import threading import threading
import time import time
import json
from typing import Dict, Any, Optional, Union
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
import os import os
@@ -17,49 +19,76 @@ import numpy as np
RED = "\033[91m" RED = "\033[91m"
RESET = "\033[0m" RESET = "\033[0m"
# Simplified document store - replaces LazyPassages # --- New Passage Loader from HNSW backend ---
class SimpleDocumentStore: class SimplePassageLoader:
"""简化的文档存储支持任意ID""" """
def __init__(self, documents: dict = None): Simple passage loader that replaces config.py dependencies
self.documents = documents or {} """
# Default demo documents def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
self.default_docs = { self.passages_data = passages_data or {}
0: "Python is a high-level, interpreted language known for simplicity.",
1: "Machine learning builds systems that learn from data.",
2: "Data structures like arrays, lists, and graphs organize data.",
}
def __getitem__(self, doc_id): def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
doc_id = int(doc_id) """Get passage by ID"""
str_id = str(passage_id)
if str_id in self.passages_data:
return {"text": self.passages_data[str_id]}
else:
# Return empty text for missing passages
return {"text": ""}
# Prefer the explicitly provided documents def __len__(self) -> int:
if doc_id in self.documents: return len(self.passages_data)
return {"text": self.documents[doc_id]}
# Fall back to the default demo documents def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
if doc_id in self.default_docs: """
return {"text": self.default_docs[doc_id]} Load passages from a JSONL file with label map support
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
"""
from pathlib import Path
import pickle
# For any other ID, return a generic document if not os.path.exists(passages_file):
fallback_docs = [ raise FileNotFoundError(f"Passages file {passages_file} not found.")
"This is a general document about technology and programming concepts.",
"This document discusses machine learning and artificial intelligence topics.",
"This content covers data structures, algorithms, and computer science fundamentals.",
"This is a document about software engineering and development practices.",
"This content focuses on databases, data management, and information systems."
]
# Pick a fallback document based on the ID if not passages_file.endswith('.jsonl'):
fallback_text = fallback_docs[doc_id % len(fallback_docs)] raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
return {"text": f"[ID:{doc_id}] {fallback_text}"}
def __len__(self): # Load label map (int -> string_id)
return len(self.documents) + len(self.default_docs) passages_dir = Path(passages_file).parent
label_map_file = passages_dir / "leann.labels.map"
label_map = {}
if label_map_file.exists():
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
# Load passages by string ID
string_id_passages = {}
with open(passages_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
passage = json.loads(line)
string_id_passages[passage['id']] = passage['text']
# Create int ID -> text mapping using label map
passages_data = {}
for int_id, string_id in label_map.items():
if string_id in string_id_passages:
passages_data[str(int_id)] = string_id_passages[string_id]
else:
print(f"WARNING: String ID {string_id} from label map not found in passages")
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
return SimplePassageLoader(passages_data)
def create_embedding_server_thread( def create_embedding_server_thread(
zmq_port=5555, zmq_port=5555,
model_name="sentence-transformers/all-mpnet-base-v2", model_name="sentence-transformers/all-mpnet-base-v2",
max_batch_size=128, max_batch_size=128,
passages_file: Optional[str] = None,
): ):
""" """
Create and run the embedding server in the current thread Create and run the embedding server in the current thread
@@ -109,15 +138,14 @@ def create_embedding_server_thread(
except Exception as e: except Exception as e:
print(f"WARNING: Model optimization failed: {e}") print(f"WARNING: Model optimization failed: {e}")
# Default demo documents # Load passages from file if provided
demo_documents = { if passages_file and os.path.exists(passages_file):
0: "Python is a high-level, interpreted language known for simplicity.", passages = load_passages_from_file(passages_file)
1: "Machine learning builds systems that learn from data.", else:
2: "Data structures like arrays, lists, and graphs organize data.", print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
} passages = SimplePassageLoader()
passages = SimpleDocumentStore(demo_documents) print(f"INFO: Loaded {len(passages)} passages.")
print(f"INFO: Loaded {len(passages)} demo documents")
class DeviceTimer: class DeviceTimer:
"""设备计时器""" """设备计时器"""
@@ -264,7 +292,13 @@ def create_embedding_server_thread(
for nid in node_ids: for nid in node_ids:
txtinfo = passages[nid] txtinfo = passages[nid]
txt = txtinfo["text"] txt = txtinfo["text"]
texts.append(txt) if txt:
texts.append(txt)
else:
# If the text is empty, we still need a placeholder for batching,
# but record its ID as missing
texts.append("")
missing_ids.append(nid)
lookup_timer.print_elapsed() lookup_timer.print_elapsed()
if missing_ids: if missing_ids:
@@ -360,18 +394,20 @@ def create_embedding_server(
max_batch_size=128, max_batch_size=128,
lazy_load_passages=False, lazy_load_passages=False,
model_name="sentence-transformers/all-mpnet-base-v2", model_name="sentence-transformers/all-mpnet-base-v2",
passages_file: Optional[str] = None,
): ):
""" """
The original create_embedding_server function stays unchanged The original create_embedding_server function stays unchanged
This is the blocking version, for running directly This is the blocking version, for running directly
""" """
create_embedding_server_thread(zmq_port, model_name, max_batch_size) create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Embedding service") parser = argparse.ArgumentParser(description="Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on") parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument("--domain", type=str, default="demo", help="Domain name") parser.add_argument("--domain", type=str, default="demo", help="Domain name")
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
parser.add_argument("--load-passages", action="store_true", default=True) parser.add_argument("--load-passages", action="store_true", default=True)
parser.add_argument("--load-embeddings", action="store_true", default=False) parser.add_argument("--load-embeddings", action="store_true", default=False)
parser.add_argument("--use-fp16", action="store_true", default=False) parser.add_argument("--use-fp16", action="store_true", default=False)
@@ -394,4 +430,5 @@ if __name__ == "__main__":
max_batch_size=args.max_batch_size, max_batch_size=args.max_batch_size,
lazy_load_passages=args.lazy_load_passages, lazy_load_passages=args.lazy_load_passages,
model_name=args.model_name, model_name=args.model_name,
passages_file=args.passages_file,
) )
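
As a compact restatement of the loading path added above (illustrative, file names are placeholders): passages live in a JSONL file with one `{"id", "text", "metadata"}` object per line, and the pickled `leann.labels.map` in the same directory remaps them to the integer labels DiskANN returns.

```python
import json
import pickle
from pathlib import Path
from typing import Dict

def load_passage_texts(passages_file: str) -> Dict[str, str]:
    passages_dir = Path(passages_file).parent
    with open(passages_dir / "leann.labels.map", "rb") as f:
        label_map = pickle.load(f)                    # {int_label: string_id}
    by_string_id = {}
    with open(passages_file, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                passage = json.loads(line)
                by_string_id[passage["id"]] = passage["text"]
    # Key by the stringified integer label, as SimplePassageLoader expects.
    return {str(i): by_string_id[sid] for i, sid in label_map.items() if sid in by_string_id}
```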


@@ -468,16 +468,27 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
# --- Write CSR HNSW graph data using unified function --- # --- Write CSR HNSW graph data using unified function ---
print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...") print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...")
# Determine storage fourcc based on prune_embeddings # Determine storage fourcc and data based on prune_embeddings
output_storage_fourcc = NULL_INDEX_FOURCC if prune_embeddings else (storage_fourcc if 'storage_fourcc' in locals() else NULL_INDEX_FOURCC)
if prune_embeddings: if prune_embeddings:
print(f" Pruning embeddings: Writing NULL storage marker.") print(f" Pruning embeddings: Writing NULL storage marker.")
storage_data = b'' output_storage_fourcc = NULL_INDEX_FOURCC
storage_data = b''
else:
# Keep embeddings - read and preserve original storage data
if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
print(f" Preserving embeddings: Reading original storage data...")
storage_data = f_in.read() # Read remaining storage data
output_storage_fourcc = storage_fourcc
print(f" Read {len(storage_data)} bytes of storage data")
else:
print(f" No embeddings found in original file (NULL storage)")
output_storage_fourcc = NULL_INDEX_FOURCC
storage_data = b''
# Use the unified write function # Use the unified write function
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np, write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
levels_np, compact_level_ptr, compact_node_offsets_np, levels_np, compact_level_ptr, compact_node_offsets_np,
compact_neighbors_data, output_storage_fourcc, storage_data if not prune_embeddings else b'') compact_neighbors_data, output_storage_fourcc, storage_data)
# Clean up memory # Clean up memory
del assign_probas_np, cum_nneighbor_per_level_np, levels_np del assign_probas_np, cum_nneighbor_per_level_np, levels_np


@@ -3,7 +3,7 @@ import os
import json import json
import struct import struct
from pathlib import Path from pathlib import Path
from typing import Dict, Any from typing import Dict, Any, List
import contextlib import contextlib
import threading import threading
import time import time
@@ -11,6 +11,7 @@ import atexit
import socket import socket
import subprocess import subprocess
import sys import sys
import pickle
from leann.embedding_server_manager import EmbeddingServerManager from leann.embedding_server_manager import EmbeddingServerManager
from .convert_to_csr import convert_hnsw_graph_to_csr from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -74,7 +75,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if self.is_recompute and not self.is_compact: if self.is_recompute and not self.is_compact:
raise ValueError("is_recompute requires is_compact=True for efficiency") raise ValueError("is_recompute requires is_compact=True for efficiency")
def build(self, data: np.ndarray, index_path: str, **kwargs): def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
"""Build HNSW index using FAISS""" """Build HNSW index using FAISS"""
from . import faiss from . import faiss
@@ -89,6 +90,12 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if not data.flags['C_CONTIGUOUS']: if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data) data = np.ascontiguousarray(data)
# Create label map: integer -> string_id
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f:
pickle.dump(label_map, f)
metric_str = self.distance_metric.lower() metric_str = self.distance_metric.lower()
metric_enum = get_metric_map().get(metric_str) metric_enum = get_metric_map().get(metric_str)
if metric_enum is None: if metric_enum is None:
@@ -119,9 +126,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if self.is_compact: if self.is_compact:
self._convert_to_csr(index_file) self._convert_to_csr(index_file)
if self.is_recompute:
self._generate_passages_file(index_dir, index_prefix, **kwargs)
except Exception as e: except Exception as e:
print(f"💥 ERROR: HNSW index build failed. Exception: {e}") print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
raise raise
@@ -155,30 +159,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
print(f"💥 ERROR: CSR conversion failed. Exception: {e}") print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
raise raise
def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
"""Generate passages file for recompute mode"""
try:
chunks = kwargs.get('chunks', [])
if not chunks:
print("INFO: No chunks data provided, skipping passages file generation")
return
# Generate node_id to text mapping
passages_data = {}
for node_id, chunk in enumerate(chunks):
passages_data[str(node_id)] = chunk["text"]
# Save passages file
passages_file = index_dir / f"{index_prefix}.passages.json"
with open(passages_file, 'w', encoding='utf-8') as f:
json.dump(passages_data, f, ensure_ascii=False, indent=2)
print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
except Exception as e:
print(f"💥 ERROR: Failed to generate passages file. Exception: {e}")
# Don't raise - this is not critical for index building
pass
class HNSWSearcher(LeannBackendSearcherInterface): class HNSWSearcher(LeannBackendSearcherInterface):
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]: def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
@@ -282,6 +262,14 @@ class HNSWSearcher(LeannBackendSearcherInterface):
self.index_dir = path.parent self.index_dir = path.parent
self.index_prefix = path.stem self.index_prefix = path.stem
# Load the label map
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
self.label_map = pickle.load(f)
index_file = self.index_dir / f"{self.index_prefix}.index" index_file = self.index_dir / f"{self.index_prefix}.index"
if not index_file.exists(): if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}") raise FileNotFoundError(f"HNSW index file not found at {index_file}")
@@ -303,7 +291,8 @@ class HNSWSearcher(LeannBackendSearcherInterface):
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False) hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0) hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
hnsw_config.external_storage_path = kwargs.get("external_storage_path") hnsw_config.external_storage_path = kwargs.get("external_storage_path")
hnsw_config.zmq_port = kwargs.get("zmq_port", 5557)
self.zmq_port = kwargs.get("zmq_port", 5557)
if self.is_pruned and not hnsw_config.is_recompute: if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.") raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
@@ -335,12 +324,13 @@ class HNSWSearcher(LeannBackendSearcherInterface):
passages_file = kwargs.get("passages_file") passages_file = kwargs.get("passages_file")
if not passages_file: if not passages_file:
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json" # Get the passages file path from meta.json
if potential_passages_file.exists(): if 'passage_sources' in self.meta and self.meta['passage_sources']:
passages_file = str(potential_passages_file) passage_source = self.meta['passage_sources'][0]
print(f"INFO: Automatically found passages file: {passages_file}") passages_file = passage_source['path']
print(f"INFO: Found passages file from metadata: {passages_file}")
else: else:
raise RuntimeError(f"FATAL: Index is pruned but no passages file found.") raise RuntimeError(f"FATAL: Index is pruned but no passage_sources found in metadata.")
zmq_port = kwargs.get("zmq_port", 5557) zmq_port = kwargs.get("zmq_port", 5557)
server_started = self.embedding_server_manager.start_server( server_started = self.embedding_server_manager.start_server(
@@ -361,15 +351,28 @@ class HNSWSearcher(LeannBackendSearcherInterface):
faiss.normalize_L2(query) faiss.normalize_L2(query)
try: try:
self._index.hnsw.efSearch = ef params = faiss.SearchParametersHNSW()
params.efSearch = ef
params.zmq_port = kwargs.get("zmq_port", self.zmq_port)
batch_size = query.shape[0] batch_size = query.shape[0]
distances = np.empty((batch_size, top_k), dtype=np.float32) distances = np.empty((batch_size, top_k), dtype=np.float32)
labels = np.empty((batch_size, top_k), dtype=np.int64) labels = np.empty((batch_size, top_k), dtype=np.int64)
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels)) self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
return {"labels": labels, "distances": distances} # Convert integer labels to string IDs
string_labels = []
for batch_labels in labels:
batch_string_labels = []
for int_label in batch_labels:
if int_label in self.label_map:
batch_string_labels.append(self.label_map[int_label])
else:
batch_string_labels.append(f"unknown_{int_label}")
string_labels.append(batch_string_labels)
return {"labels": string_labels, "distances": distances}
except Exception as e: except Exception as e:
print(f"💥 ERROR: HNSW search failed. Exception: {e}") print(f"💥 ERROR: HNSW search failed. Exception: {e}")


@@ -58,21 +58,46 @@ class SimplePassageLoader:
def load_passages_from_file(passages_file: str) -> SimplePassageLoader: def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
""" """
Load passages from a JSON file Load passages from a JSONL file with label map support
Expected format: {"passage_id": "passage_text", ...} Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
""" """
if not os.path.exists(passages_file): if not os.path.exists(passages_file):
print(f"Warning: Passages file {passages_file} not found. Using empty loader.") raise FileNotFoundError(f"Passages file {passages_file} not found.")
return SimplePassageLoader()
try: if not passages_file.endswith('.jsonl'):
with open(passages_file, 'r', encoding='utf-8') as f: raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
passages_data = json.load(f)
print(f"Loaded {len(passages_data)} passages from {passages_file}") # Load label map (int -> string_id)
return SimplePassageLoader(passages_data) passages_dir = Path(passages_file).parent
except Exception as e: label_map_file = passages_dir / "leann.labels.map"
print(f"Error loading passages from {passages_file}: {e}")
return SimplePassageLoader() label_map = {}
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
# Load passages by string ID
string_id_passages = {}
with open(passages_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
passage = json.loads(line)
string_id_passages[passage['id']] = passage['text']
# Create int ID -> text mapping using label map
passages_data = {}
for int_id, string_id in label_map.items():
if string_id in string_id_passages:
passages_data[str(int_id)] = string_id_passages[string_id]
else:
print(f"WARNING: String ID {string_id} from label map not found in passages")
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
return SimplePassageLoader(passages_data)
def create_hnsw_embedding_server( def create_hnsw_embedding_server(
passages_file: Optional[str] = None, passages_file: Optional[str] = None,


@@ -0,0 +1,7 @@
# packages/leann-core/src/leann/__init__.py
from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import BACKEND_REGISTRY, autodiscover_backends
autodiscover_backends()
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"]
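
With auto-discovery wired into the package `__init__`, example scripts no longer need side-effect imports such as `import leann_backend_hnsw` (removed in the example diff above). A minimal sketch with placeholder text and index path:

```python
from leann.api import LeannBuilder   # importing leann runs autodiscover_backends()

builder = LeannBuilder(backend_name="hnsw", embedding_model="facebook/contriever")
builder.add_text("Machine learning builds systems that learn from data.")
builder.build_index("./demo_index/demo.leann")
```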


@@ -7,6 +7,8 @@ import json
from pathlib import Path from pathlib import Path
import openai import openai
from dataclasses import dataclass, field from dataclasses import dataclass, field
import uuid
import pickle
# --- Helper Functions for Embeddings --- # --- Helper Functions for Embeddings ---
@@ -56,11 +58,45 @@ def _get_embedding_dimensions(model_name: str) -> int:
@dataclass @dataclass
class SearchResult: class SearchResult:
"""Represents a single search result.""" """Represents a single search result."""
id: int id: str
score: float score: float
text: str text: str
metadata: Dict[str, Any] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict)
class PassageManager:
"""Manages passage data and lazy loading from JSONL files."""
def __init__(self, passage_sources: List[Dict[str, Any]]):
self.offset_maps = {}
self.passage_files = {}
for source in passage_sources:
if source["type"] == "jsonl":
passage_file = source["path"]
index_file = source["index_path"]
if not os.path.exists(index_file):
raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, 'rb') as f:
offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
def get_passage(self, passage_id: str) -> Dict[str, Any]:
"""Lazy load a passage by ID."""
for passage_file, offset_map in self.offset_maps.items():
if passage_id in offset_map:
offset = offset_map[passage_id]
with open(passage_file, 'r', encoding='utf-8') as f:
f.seek(offset)
line = f.readline()
return json.loads(line)
raise KeyError(f"Passage ID not found: {passage_id}")
# --- Core Classes --- # --- Core Classes ---
class LeannBuilder: class LeannBuilder:
@@ -82,7 +118,26 @@ class LeannBuilder:
print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.") print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None): def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
self.chunks.append({"text": text, "metadata": metadata or {}}) if metadata is None:
metadata = {}
# Check if ID is provided in metadata
passage_id = metadata.get('id')
if passage_id is None:
passage_id = str(uuid.uuid4())
else:
# Validate uniqueness
existing_ids = {chunk['id'] for chunk in self.chunks}
if passage_id in existing_ids:
raise ValueError(f"Duplicate passage ID: {passage_id}")
# Store the definitive ID with the chunk
chunk_data = {
"id": passage_id,
"text": text,
"metadata": metadata
}
self.chunks.append(chunk_data)
def build_index(self, index_path: str): def build_index(self, index_path: str):
if not self.chunks: if not self.chunks:
@@ -92,28 +147,65 @@ class LeannBuilder:
self.dimensions = _get_embedding_dimensions(self.embedding_model) self.dimensions = _get_embedding_dimensions(self.embedding_model)
print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}") print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")
path = Path(index_path)
index_dir = path.parent
index_name = path.name
# Ensure the directory exists
index_dir.mkdir(parents=True, exist_ok=True)
# Create the passages.jsonl file and offset index
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
offset_map = {}
with open(passages_file, 'w', encoding='utf-8') as f:
for chunk in self.chunks:
offset = f.tell()
passage_data = {
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk["metadata"]
}
json.dump(passage_data, f, ensure_ascii=False)
f.write('\n')
offset_map[chunk["id"]] = offset
# Save the offset map
with open(offset_file, 'wb') as f:
pickle.dump(offset_map, f)
# Compute embeddings
texts_to_embed = [c["text"] for c in self.chunks] texts_to_embed = [c["text"] for c in self.chunks]
embeddings = _compute_embeddings(texts_to_embed, self.embedding_model) embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)
# Extract string IDs for the backend
string_ids = [chunk["id"] for chunk in self.chunks]
# Build the vector index
current_backend_kwargs = self.backend_kwargs.copy() current_backend_kwargs = self.backend_kwargs.copy()
current_backend_kwargs['dimensions'] = self.dimensions current_backend_kwargs['dimensions'] = self.dimensions
builder_instance = self.backend_factory.builder(**current_backend_kwargs) builder_instance = self.backend_factory.builder(**current_backend_kwargs)
build_kwargs = current_backend_kwargs.copy() builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
build_kwargs['chunks'] = self.chunks
builder_instance.build(embeddings, index_path, **build_kwargs)
index_dir = Path(index_path).parent # Create the lightweight meta.json file
leann_meta_path = index_dir / f"{Path(index_path).name}.meta.json" leann_meta_path = index_dir / f"{index_name}.meta.json"
meta_data = { meta_data = {
"version": "0.1.0", "version": "1.0",
"backend_name": self.backend_name, "backend_name": self.backend_name,
"embedding_model": self.embedding_model, "embedding_model": self.embedding_model,
"dimensions": self.dimensions, "dimensions": self.dimensions,
"backend_kwargs": self.backend_kwargs, "backend_kwargs": self.backend_kwargs,
"num_chunks": len(self.chunks), "passage_sources": [
"chunks": self.chunks, {
"type": "jsonl",
"path": str(passages_file),
"index_path": str(offset_file)
}
]
} }
with open(leann_meta_path, 'w', encoding='utf-8') as f: with open(leann_meta_path, 'w', encoding='utf-8') as f:
json.dump(meta_data, f, indent=2) json.dump(meta_data, f, indent=2)
@@ -136,14 +228,16 @@ class LeannSearcher:
backend_name = self.meta_data['backend_name'] backend_name = self.meta_data['backend_name']
self.embedding_model = self.meta_data['embedding_model'] self.embedding_model = self.meta_data['embedding_model']
# Initialize the passage manager
passage_sources = self.meta_data.get('passage_sources', [])
self.passage_manager = PassageManager(passage_sources)
backend_factory = BACKEND_REGISTRY.get(backend_name) backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None: if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.") raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.")
final_kwargs = self.meta_data.get("backend_kwargs", {}) final_kwargs = backend_kwargs.copy()
final_kwargs.update(backend_kwargs) final_kwargs['meta'] = self.meta_data
if 'dimensions' not in final_kwargs:
final_kwargs['dimensions'] = self.meta_data.get('dimensions')
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs) self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.") print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")
@@ -155,15 +249,17 @@ class LeannSearcher:
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs) results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
enriched_results = [] enriched_results = []
for label, dist in zip(results['labels'][0], results['distances'][0]): for string_id, dist in zip(results['labels'][0], results['distances'][0]):
if label < len(self.meta_data['chunks']): try:
chunk_info = self.meta_data['chunks'][label] passage_data = self.passage_manager.get_passage(string_id)
enriched_results.append(SearchResult( enriched_results.append(SearchResult(
id=label, id=string_id,
score=dist, score=dist,
text=chunk_info['text'], text=passage_data['text'],
metadata=chunk_info.get('metadata', {}) metadata=passage_data.get('metadata', {})
)) ))
except KeyError:
print(f"WARNING: Passage ID '{string_id}' not found in passage files")
return enriched_results return enriched_results
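
For reference, an illustrative `meta.json` produced by the new `build_index` (values are placeholders): chunks are no longer embedded in the metadata; only pointers to the JSONL passages file and its pickled offset index are recorded.

```python
meta_data = {
    "version": "1.0",
    "backend_name": "hnsw",
    "embedding_model": "facebook/contriever",
    "dimensions": 768,
    "backend_kwargs": {"graph_degree": 32, "complexity": 64},
    "passage_sources": [
        {
            "type": "jsonl",
            "path": "./demo_index/demo.leann.passages.jsonl",
            "index_path": "./demo_index/demo.leann.passages.idx",
        }
    ],
}
```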


@@ -1,6 +1,9 @@
# packages/leann-core/src/leann/registry.py # packages/leann-core/src/leann/registry.py
from typing import Dict, TYPE_CHECKING from typing import Dict, TYPE_CHECKING
import importlib
import importlib.metadata
if TYPE_CHECKING: if TYPE_CHECKING:
from leann.interface import LeannBackendFactoryInterface from leann.interface import LeannBackendFactoryInterface
@@ -13,3 +16,21 @@ def register_backend(name: str):
BACKEND_REGISTRY[name] = cls BACKEND_REGISTRY[name] = cls
return cls return cls
return decorator return decorator
def autodiscover_backends():
"""Automatically discovers and imports all 'leann-backend-*' packages."""
print("INFO: Starting backend auto-discovery...")
discovered_backends = []
for dist in importlib.metadata.distributions():
dist_name = dist.metadata['name']
if dist_name.startswith('leann-backend-'):
backend_module_name = dist_name.replace('-', '_')
discovered_backends.append(backend_module_name)
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
try:
importlib.import_module(backend_module_name)
# Registration message is printed by the decorator
except ImportError as e:
print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
print("INFO: Backend auto-discovery finished.")


File diff suppressed because it is too large


@@ -0,0 +1,68 @@
# HNSW Index Storage Optimization
This document explains the storage optimization features available in the HNSW backend.
## Storage Modes
The HNSW backend supports two orthogonal optimization techniques:
### 1. CSR Compression (`is_compact=True`)
- Converts the graph structure from standard format to Compressed Sparse Row (CSR) format
- Reduces memory overhead from graph adjacency storage
- Maintains all embedding data for direct access
### 2. Embedding Pruning (`is_recompute=True`)
- Removes embedding vectors from the index file
- Replaces them with a NULL storage marker
- Requires recomputation via embedding server during search
- Must be used with `is_compact=True` for efficiency
## Performance Impact
**Storage Reduction (100 vectors, 384 dimensions):**
```
Standard format: 168 KB (embeddings + graph)
CSR only: 160 KB (embeddings + compressed graph)
CSR + Pruned: 6 KB (compressed graph only)
```
**Key Benefits:**
- **CSR compression**: ~5% size reduction from graph optimization
- **Embedding pruning**: ~95% size reduction by removing embeddings
- **Combined**: Up to 96% total storage reduction
## Usage
```python
# Standard format (largest)
builder = LeannBuilder(
backend_name="hnsw",
is_compact=False,
is_recompute=False
)
# CSR compressed (medium)
builder = LeannBuilder(
backend_name="hnsw",
is_compact=True,
is_recompute=False
)
# CSR + Pruned (smallest, requires embedding server)
builder = LeannBuilder(
backend_name="hnsw",
is_compact=True, # Required for pruning
is_recompute=True # Default: enabled
)
```
## Trade-offs
| Mode | Storage | Search Speed | Memory Usage | Setup Complexity |
|------|---------|--------------|--------------|------------------|
| Standard | Largest | Fastest | Highest | Simple |
| CSR | Medium | Fast | Medium | Simple |
| CSR + Pruned | Smallest | Slower* | Lowest | Complex** |
*Requires network round-trip to embedding server for recomputation
**Needs embedding server and passages file for search
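
For completeness, a hedged sketch of querying a CSR + pruned index (index path and query are placeholders): because the embeddings were removed, the searcher detects the pruned storage, forces recompute, and starts the embedding server on demand.

```python
from leann.api import LeannSearcher

searcher = LeannSearcher("./demo_index/demo.leann")
results = searcher.search("data structures", top_k=5)   # recompute happens transparently
for res in results:
    print(res.id, round(res.score, 4), res.text[:60])
```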


@@ -0,0 +1,156 @@
#!/usr/bin/env python3
"""
Sanity check script to verify HNSW index pruning effectiveness.
Tests the difference in file sizes between pruned and non-pruned indices.
"""
import os
import sys
import tempfile
from pathlib import Path
import numpy as np
import json
# Add the project root to the Python path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
# Import backend packages to trigger plugin registration
import leann_backend_hnsw
from leann.api import LeannBuilder
def create_sample_documents(num_docs=1000):
"""Create sample documents for testing"""
documents = []
for i in range(num_docs):
documents.append(f"Sample document {i} with some random text content for testing purposes.")
return documents
def build_index(documents, output_dir, is_recompute=True):
"""Build HNSW index with specified recompute setting"""
index_path = os.path.join(output_dir, "test_index.hnsw")
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
M=16,
efConstruction=100,
distance_metric="mips",
is_compact=True,
is_recompute=is_recompute
)
for doc in documents:
builder.add_text(doc)
builder.build_index(index_path)
return index_path
def get_file_size(filepath):
"""Get file size in bytes"""
return os.path.getsize(filepath)
def main():
print("🔍 HNSW Pruning Sanity Check")
print("=" * 50)
# Create sample data
print("📊 Creating sample documents...")
documents = create_sample_documents(num_docs=1000)
print(f" Number of documents: {len(documents)}")
with tempfile.TemporaryDirectory() as temp_dir:
print(f"📁 Working in temporary directory: {temp_dir}")
# Build index with pruning (is_recompute=True)
print("\n🔨 Building index with pruning enabled (is_recompute=True)...")
pruned_dir = os.path.join(temp_dir, "pruned")
os.makedirs(pruned_dir, exist_ok=True)
pruned_index_path = build_index(documents, pruned_dir, is_recompute=True)
# Check what files were actually created
print(f" Looking for index files at: {pruned_index_path}")
import glob
files = glob.glob(f"{pruned_index_path}*")
print(f" Found files: {files}")
# Try to find the actual index file
if os.path.exists(f"{pruned_index_path}.index"):
pruned_index_file = f"{pruned_index_path}.index"
else:
# Look for any .index file in the directory
index_files = glob.glob(f"{pruned_dir}/*.index")
if index_files:
pruned_index_file = index_files[0]
else:
raise FileNotFoundError(f"No .index file found in {pruned_dir}")
pruned_size = get_file_size(pruned_index_file)
print(f" ✅ Pruned index built successfully")
print(f" 📏 Pruned index size: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)")
# Build index without pruning (is_recompute=False)
print("\n🔨 Building index without pruning (is_recompute=False)...")
non_pruned_dir = os.path.join(temp_dir, "non_pruned")
os.makedirs(non_pruned_dir, exist_ok=True)
non_pruned_index_path = build_index(documents, non_pruned_dir, is_recompute=False)
# Check what files were actually created
print(f" Looking for index files at: {non_pruned_index_path}")
files = glob.glob(f"{non_pruned_index_path}*")
print(f" Found files: {files}")
# Try to find the actual index file
if os.path.exists(f"{non_pruned_index_path}.index"):
non_pruned_index_file = f"{non_pruned_index_path}.index"
else:
# Look for any .index file in the directory
index_files = glob.glob(f"{non_pruned_dir}/*.index")
if index_files:
non_pruned_index_file = index_files[0]
else:
raise FileNotFoundError(f"No .index file found in {non_pruned_dir}")
non_pruned_size = get_file_size(non_pruned_index_file)
print(f" ✅ Non-pruned index built successfully")
print(f" 📏 Non-pruned index size: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)")
# Compare sizes
print("\n📊 Comparison Results:")
print("=" * 30)
size_diff = non_pruned_size - pruned_size
size_ratio = pruned_size / non_pruned_size if non_pruned_size > 0 else 0
reduction_percent = (1 - size_ratio) * 100
print(f"Non-pruned index: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)")
print(f"Pruned index: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)")
print(f"Size difference: {size_diff:,} bytes ({size_diff/1024:.1f} KB)")
print(f"Size ratio: {size_ratio:.3f}")
print(f"Size reduction: {reduction_percent:.1f}%")
# Verify pruning effectiveness
print("\n🔍 Verification:")
if size_diff > 0:
print(" ✅ Pruning is effective - pruned index is smaller")
if reduction_percent > 10:
print(f" ✅ Significant size reduction: {reduction_percent:.1f}%")
else:
print(f" ⚠️ Small size reduction: {reduction_percent:.1f}%")
else:
print(" ❌ Pruning appears ineffective - no size reduction")
# Check if passages files were created
pruned_passages = f"{pruned_index_path}.passages.json"
non_pruned_passages = f"{non_pruned_index_path}.passages.json"
print(f"\n📄 Passages files:")
print(f" Pruned passages file exists: {os.path.exists(pruned_passages)}")
print(f" Non-pruned passages file exists: {os.path.exists(non_pruned_passages)}")
return True
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)