refactor: nits

Author: Andy Lee
Date:   2025-07-16 15:39:58 -07:00
parent 7b9406a3ea
commit 2a1a152073
4 changed files with 102 additions and 62 deletions

View File

@@ -1,6 +1,5 @@
 import numpy as np
 import os
-import json
 import struct
 from pathlib import Path
 from typing import Dict, Any, List, Literal
@@ -12,17 +11,20 @@ from leann.registry import register_backend
 from leann.interface import (
     LeannBackendFactoryInterface,
     LeannBackendBuilderInterface,
-    LeannBackendSearcherInterface
+    LeannBackendSearcherInterface,
 )
 def _get_diskann_metrics():
-    from . import _diskannpy as diskannpy
+    from . import _diskannpy as diskannpy  # type: ignore
     return {
         "mips": diskannpy.Metric.INNER_PRODUCT,
         "l2": diskannpy.Metric.L2,
         "cosine": diskannpy.Metric.COSINE,
     }
 @contextlib.contextmanager
 def chdir(path):
     original_dir = os.getcwd()
@@ -32,13 +34,15 @@ def chdir(path):
     finally:
         os.chdir(original_dir)
 def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
     num_vectors, dim = data.shape
-    with open(file_path, 'wb') as f:
-        f.write(struct.pack('I', num_vectors))
-        f.write(struct.pack('I', dim))
+    with open(file_path, "wb") as f:
+        f.write(struct.pack("I", num_vectors))
+        f.write(struct.pack("I", dim))
         f.write(data.tobytes())
 @register_backend("diskann")
 class DiskannBackend(LeannBackendFactoryInterface):
     @staticmethod
@@ -49,6 +53,7 @@ class DiskannBackend(LeannBackendFactoryInterface):
     def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
         return DiskannSearcher(index_path, **kwargs)
 class DiskannBuilder(LeannBackendBuilderInterface):
     def __init__(self, **kwargs):
         self.build_params = kwargs
@@ -67,32 +72,46 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         label_map = {i: str_id for i, str_id in enumerate(ids)}
         label_map_file = index_dir / "leann.labels.map"
-        with open(label_map_file, 'wb') as f:
+        with open(label_map_file, "wb") as f:
             pickle.dump(label_map, f)
         build_kwargs = {**self.build_params, **kwargs}
-        metric_enum = _get_diskann_metrics().get(build_kwargs.get("distance_metric", "mips").lower())
+        metric_enum = _get_diskann_metrics().get(
+            build_kwargs.get("distance_metric", "mips").lower()
+        )
         if metric_enum is None:
-            raise ValueError(f"Unsupported distance_metric.")
+            raise ValueError("Unsupported distance_metric.")
         try:
-            from . import _diskannpy as diskannpy
+            from . import _diskannpy as diskannpy  # type: ignore
             with chdir(index_dir):
                 diskannpy.build_disk_float_index(
-                    metric_enum, data_filename, index_prefix,
-                    build_kwargs.get("complexity", 64), build_kwargs.get("graph_degree", 32),
-                    build_kwargs.get("search_memory_maximum", 4.0), build_kwargs.get("build_memory_maximum", 8.0),
-                    build_kwargs.get("num_threads", 8), build_kwargs.get("pq_disk_bytes", 0), ""
+                    metric_enum,
+                    data_filename,
+                    index_prefix,
+                    build_kwargs.get("complexity", 64),
+                    build_kwargs.get("graph_degree", 32),
+                    build_kwargs.get("search_memory_maximum", 4.0),
+                    build_kwargs.get("build_memory_maximum", 8.0),
+                    build_kwargs.get("num_threads", 8),
+                    build_kwargs.get("pq_disk_bytes", 0),
+                    "",
                 )
         finally:
             temp_data_file = index_dir / data_filename
             if temp_data_file.exists():
                 os.remove(temp_data_file)
 class DiskannSearcher(BaseSearcher):
     def __init__(self, index_path: str, **kwargs):
-        super().__init__(index_path, backend_module_name="leann_backend_diskann.embedding_server", **kwargs)
-        from . import _diskannpy as diskannpy
+        super().__init__(
+            index_path,
+            backend_module_name="leann_backend_diskann.embedding_server",
+            **kwargs,
+        )
+        from . import _diskannpy as diskannpy  # type: ignore
         distance_metric = kwargs.get("distance_metric", "mips").lower()
         metric_enum = _get_diskann_metrics().get(distance_metric)
@@ -104,23 +123,33 @@ class DiskannSearcher(BaseSearcher):
         full_index_prefix = str(self.index_dir / self.index_path.stem)
         self._index = diskannpy.StaticDiskFloatIndex(
-            metric_enum, full_index_prefix, self.num_threads,
-            kwargs.get("num_nodes_to_cache", 0), 1, self.zmq_port, "", ""
+            metric_enum,
+            full_index_prefix,
+            self.num_threads,
+            kwargs.get("num_nodes_to_cache", 0),
+            1,
+            self.zmq_port,
+            "",
+            "",
         )
-    def search(self, query: np.ndarray, top_k: int,
-               complexity: int = 64,
-               beam_width: int = 1,
-               prune_ratio: float = 0.0,
-               recompute_embeddings: bool = False,
-               pruning_strategy: Literal["global", "local", "proportional"] = "global",
-               zmq_port: int = 5557,
-               batch_recompute: bool = False,
-               dedup_node_dis: bool = False,
-               **kwargs) -> Dict[str, Any]:
+    def search(
+        self,
+        query: np.ndarray,
+        top_k: int,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        batch_recompute: bool = False,
+        dedup_node_dis: bool = False,
+        **kwargs,
+    ) -> Dict[str, Any]:
         """
         Search for nearest neighbors using DiskANN index.
         Args:
             query: Query vectors (B, D) where B is batch size, D is dimension
             top_k: Number of nearest neighbors to return
@@ -130,26 +159,30 @@ class DiskannSearcher(BaseSearcher):
             recompute_embeddings: Whether to fetch fresh embeddings from server
             pruning_strategy: PQ candidate selection strategy:
                 - "global": Use global pruning strategy (default)
                 - "local": Use local pruning strategy
                 - "proportional": Not supported in DiskANN, falls back to global
             zmq_port: ZMQ port for embedding server
             batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
             dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
             **kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
         Returns:
             Dict with 'labels' (list of lists) and 'distances' (ndarray)
         """
         # DiskANN doesn't support "proportional" strategy
         if pruning_strategy == "proportional":
-            raise NotImplementedError("DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead.")
+            raise NotImplementedError(
+                "DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
+            )
         # Use recompute_embeddings parameter
         use_recompute = recompute_embeddings
         if use_recompute:
             meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
             if not meta_file_path.exists():
-                raise RuntimeError(f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}")
+                raise RuntimeError(
+                    f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
+                )
             self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
         if query.dtype != np.float32:
@@ -162,17 +195,27 @@ class DiskannSearcher(BaseSearcher):
             use_global_pruning = True
         labels, distances = self._index.batch_search(
-            query, query.shape[0], top_k,
-            complexity, beam_width, self.num_threads,
-            kwargs.get("USE_DEFERRED_FETCH", False),
+            query,
+            query.shape[0],
+            top_k,
+            complexity,
+            beam_width,
+            self.num_threads,
+            kwargs.get("USE_DEFERRED_FETCH", False),
             kwargs.get("skip_search_reorder", False),
             use_recompute,
             dedup_node_dis,
             prune_ratio,
             batch_recompute,
-            use_global_pruning
+            use_global_pruning,
         )
-        string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels]
+        string_labels = [
+            [
+                self.label_map.get(int_label, f"unknown_{int_label}")
+                for int_label in batch_labels
+            ]
+            for batch_labels in labels
+        ]
         return {"labels": string_labels, "distances": distances}

View File

@@ -1,6 +1,5 @@
 import numpy as np
 import os
-import json
 from pathlib import Path
 from typing import Dict, Any, List, Literal
 import pickle
@@ -18,7 +17,7 @@ from leann.interface import (
 def get_metric_map():
-    from . import faiss
+    from . import faiss  # type: ignore
     return {
         "mips": faiss.METRIC_INNER_PRODUCT,
@@ -49,7 +48,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         self.dimensions = self.build_params.get("dimensions")
     def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
-        from . import faiss
+        from . import faiss  # type: ignore
         path = Path(index_path)
         index_dir = path.parent
@@ -117,7 +116,7 @@ class HNSWSearcher(BaseSearcher):
backend_module_name="leann_backend_hnsw.hnsw_embedding_server", backend_module_name="leann_backend_hnsw.hnsw_embedding_server",
**kwargs, **kwargs,
) )
from . import faiss from . import faiss # type: ignore
self.distance_metric = self.meta.get("distance_metric", "mips").lower() self.distance_metric = self.meta.get("distance_metric", "mips").lower()
metric_enum = get_metric_map().get(self.distance_metric) metric_enum = get_metric_map().get(self.distance_metric)
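
Both backends resolve the user-facing `distance_metric` string through a small lookup table and treat a missing key as an error. The sketch below only illustrates that pattern with a stand-in enum; it is not the backends' actual code and the names in it are hypothetical.

```python
from enum import Enum, auto


class Metric(Enum):  # stand-in for diskannpy.Metric / faiss METRIC_* constants
    INNER_PRODUCT = auto()
    L2 = auto()
    COSINE = auto()


METRIC_MAP = {"mips": Metric.INNER_PRODUCT, "l2": Metric.L2, "cosine": Metric.COSINE}


def resolve_metric(distance_metric: str = "mips") -> Metric:
    # Normalize case first, then fail loudly on anything unknown,
    # mirroring the .get(...) + None check used in both backends.
    metric = METRIC_MAP.get(distance_metric.lower())
    if metric is None:
        raise ValueError(f"Unsupported distance_metric: {distance_metric}")
    return metric
```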

View File

@@ -14,8 +14,7 @@ import torch
 from .registry import BACKEND_REGISTRY
 from .interface import LeannBackendFactoryInterface
-from .chat import get_llm
-# --- The Correct, Verified Embedding Logic from old_code.py ---
 def compute_embeddings(
@@ -28,7 +27,7 @@ def compute_embeddings(
         from sentence_transformers import SentenceTransformer
     except ImportError as e:
         raise RuntimeError(
-            f"sentence-transformers not available. Install with: pip install sentence-transformers"
+            "sentence-transformers not available. Install with: uv pip install sentence-transformers"
         ) from e
     # Load model using sentence-transformers
@@ -61,7 +60,7 @@ def compute_embeddings_mlx(chunks: List[str], model_name: str) -> np.ndarray:
         from mlx_lm.utils import load
     except ImportError as e:
         raise RuntimeError(
-            f"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
+            "MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
         ) from e
     print(
@@ -75,7 +74,7 @@ def compute_embeddings_mlx(chunks: List[str], model_name: str) -> np.ndarray:
     all_embeddings = []
     for chunk in chunks:
         # Tokenize
-        token_ids = tokenizer.encode(chunk)
+        token_ids = tokenizer.encode(chunk)  # type: ignore
         # Convert to MLX array and add batch dimension
         input_ids = mx.array([token_ids])
@@ -95,9 +94,6 @@ def compute_embeddings_mlx(chunks: List[str], model_name: str) -> np.ndarray:
     return np.stack(all_embeddings)
-# --- Core API Classes (Restored and Unchanged) ---
 @dataclass
 class SearchResult:
     id: str
@@ -255,7 +251,7 @@ class LeannSearcher:
         self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
     def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
-        print(f"🔍 DEBUG LeannSearcher.search() called:")
+        print("🔍 DEBUG LeannSearcher.search() called:")
         print(f"   Query: '{query}'")
         print(f"   Top_k: {top_k}")
         print(f"   Search kwargs: {search_kwargs}")
@@ -302,12 +298,13 @@ class LeannSearcher:
         return enriched_results
+from .chat import get_llm
 class LeannChat:
     def __init__(
-        self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, enable_warmup: bool = False, **kwargs
+        self,
+        index_path: str,
+        llm_config: Optional[Dict[str, Any]] = None,
+        enable_warmup: bool = False,
+        **kwargs,
     ):
         self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
         self.llm = get_llm(llm_config)
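
The two classes touched above, `LeannSearcher` and `LeannChat`, are the user-facing entry points. The following is a minimal usage sketch, not code from this commit: the `leann.api` module path, the index location, and the `llm_config` keys and values are assumptions chosen for illustration.

```python
from leann.api import LeannSearcher, LeannChat  # module path assumed

# Hypothetical index built previously by one of the backend builders.
searcher = LeannSearcher("indexes/demo.index")
for hit in searcher.search("how does recompute work?", top_k=5):
    print(hit.id)

# LeannChat builds its own searcher and obtains an LLM via get_llm(llm_config).
chat = LeannChat(
    "indexes/demo.index",
    llm_config={"type": "ollama", "model": "llama3.1"},  # illustrative keys/values only
    enable_warmup=False,
)
```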

View File

@@ -1,16 +1,17 @@
 from abc import ABC, abstractmethod
 import numpy as np
-from typing import Dict, Any, Literal
+from typing import Dict, Any, List, Literal
 class LeannBackendBuilderInterface(ABC):
     """Backend interface for building indexes"""
     @abstractmethod
-    def build(self, data: np.ndarray, index_path: str, **kwargs) -> None:
+    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs) -> None:
         """Build index
         Args:
             data: Vector data (N, D)
+            ids: List of string IDs for each vector
             index_path: Path to save index
             **kwargs: Backend-specific build parameters
         """
@@ -47,7 +48,7 @@ class LeannBackendSearcherInterface(ABC):
             beam_width: Number of parallel search paths/IO requests per iteration
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
-            pruning_strategy: PQ candidate selection strategy - "global", "local", or "proportional"
+            pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
             zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters
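
Because the abstract `build()` now carries the vectors' string IDs, any backend implementation must accept and persist them alongside the data. Below is a toy, hypothetical implementation of the updated interface, assuming `build()` is the only abstract method a builder must supply; it only illustrates the new `ids` parameter and is not one of the real backends.

```python
import numpy as np
from typing import List

from leann.interface import LeannBackendBuilderInterface


class InMemoryBuilder(LeannBackendBuilderInterface):
    """Toy builder illustrating the updated build() signature."""

    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs) -> None:
        # One string ID per vector, as the new interface requires.
        if data.shape[0] != len(ids):
            raise ValueError("data and ids must have the same length")
        # Persist vectors and their IDs together; the real DiskANN builder
        # instead pickles an integer-to-string label map (leann.labels.map).
        np.savez(index_path, vectors=data, ids=np.array(ids))
```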