Merge remote-tracking branch 'origin/main' into datastore-reproduce
This commit is contained in:
@@ -2,6 +2,33 @@
|
||||
cmake_minimum_required(VERSION 3.24)
|
||||
project(leann_backend_hnsw_wrapper)
|
||||
|
||||
# Set OpenMP path for macOS
|
||||
if(APPLE)
|
||||
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
|
||||
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
|
||||
set(OpenMP_C_LIB_NAMES "omp")
|
||||
set(OpenMP_CXX_LIB_NAMES "omp")
|
||||
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
||||
endif()
|
||||
|
||||
# Build ZeroMQ from source
|
||||
set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
|
||||
set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
|
||||
set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
|
||||
set(WITH_DOCS OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_STATIC ON CACHE BOOL "" FORCE)
|
||||
add_subdirectory(third_party/libzmq)
|
||||
|
||||
# Add cppzmq headers
|
||||
include_directories(third_party/cppzmq)
|
||||
|
||||
# Configure msgpack-c - disable boost dependency
|
||||
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
||||
add_compile_definitions(MSGPACK_NO_BOOST)
|
||||
include_directories(third_party/msgpack-c/include)
|
||||
|
||||
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List
|
||||
import pickle
|
||||
import shutil
|
||||
|
||||
from leann.searcher_base import BaseSearcher
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
@@ -77,17 +78,29 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
self._convert_to_csr(index_file)
|
||||
|
||||
def _convert_to_csr(self, index_file: Path):
|
||||
"""Convert built index to CSR format"""
|
||||
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
|
||||
print(f"INFO: Converting HNSW index to {mode_str} format...")
|
||||
|
||||
csr_temp_file = index_file.with_suffix(".csr.tmp")
|
||||
|
||||
success = convert_hnsw_graph_to_csr(
|
||||
str(index_file), str(csr_temp_file), prune_embeddings=self.is_recompute
|
||||
str(index_file),
|
||||
str(csr_temp_file),
|
||||
prune_embeddings=self.is_recompute
|
||||
)
|
||||
|
||||
if success:
|
||||
import shutil
|
||||
print("✅ CSR conversion successful.")
|
||||
index_file_old = index_file.with_suffix(".old")
|
||||
shutil.move(str(index_file), str(index_file_old))
|
||||
shutil.move(str(csr_temp_file), str(index_file))
|
||||
print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
|
||||
else:
|
||||
# Clean up and fail fast
|
||||
if csr_temp_file.exists():
|
||||
os.remove(csr_temp_file)
|
||||
raise RuntimeError("CSR conversion failed")
|
||||
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
|
||||
|
||||
class HNSWSearcher(BaseSearcher):
|
||||
def __init__(self, index_path: str, **kwargs):
|
||||
@@ -99,7 +112,10 @@ class HNSWSearcher(BaseSearcher):
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
|
||||
self.is_compact, self.is_pruned = self._get_index_storage_status_from_meta()
|
||||
self.is_compact, self.is_pruned = (
|
||||
self.meta.get('is_compact', True),
|
||||
self.meta.get('is_pruned', True)
|
||||
)
|
||||
|
||||
index_file = self.index_dir / f"{self.index_path.stem}.index"
|
||||
if not index_file.exists():
|
||||
@@ -114,11 +130,6 @@ class HNSWSearcher(BaseSearcher):
|
||||
|
||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||
|
||||
def _get_index_storage_status_from_meta(self) -> tuple[bool, bool]:
|
||||
is_compact = self.meta.get('is_compact', True)
|
||||
is_pruned = self.meta.get('is_pruned', True)
|
||||
return is_compact, is_pruned
|
||||
|
||||
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
|
||||
from . import faiss
|
||||
|
||||
|
||||
Reference in New Issue
Block a user