refactor: chat and base searcher

Andy Lee
2025-07-11 16:34:12 +00:00
parent 8bffb1e5b8
commit 0da08fbe38
5 changed files with 353 additions and 428 deletions

View File

@@ -1,6 +1,7 @@
import faulthandler
faulthandler.enable()
import argparse
from llama_index.core import SimpleDirectoryReader, Settings
from llama_index.core.readers.base import BaseReader
from llama_index.node_parser.docling import DoclingNodeParser
@@ -50,7 +51,7 @@ if not INDEX_DIR.exists():
# CSR compact mode with recompute
builder = LeannBuilder(
backend_name="diskann",
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
@@ -67,14 +68,27 @@ if not INDEX_DIR.exists():
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
async def main():
async def main(args):
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=INDEX_PATH)
llm_config = {
"type": args.llm,
"model": args.model,
"host": args.host
}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True)
chat_response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True)
print(f"Leann: {chat_response}")
if __name__ == "__main__":
asyncio.run(main())
parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf"], help="The LLM backend to use.")
parser.add_argument("--model", type=str, default='meta-llama/Llama-3.2-3B-Instruct', help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf).")
parser.add_argument("--host", type=str, default="http://localhost:11434", help="The host for the Ollama API.")
args = parser.parse_args()
asyncio.run(main(args))
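For reference, a minimal usage sketch of the new entry point (the index path is a placeholder; only the keys assembled in main(args) above are assumed):
# Sketch only: mirrors the llm_config construction in main(args) above.
# Equivalent CLI: python <this script> --llm ollama --model llama3:8b
llm_config = {"type": "ollama", "model": "llama3:8b", "host": "http://localhost:11434"}
chat = LeannChat(index_path="demo.leann", llm_config=llm_config)  # index path is hypothetical
print(chat.ask("What does LEANN do?", top_k=3, recompute_beighbor_embeddings=True))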

View File

@@ -5,21 +5,16 @@ import struct
from pathlib import Path
from typing import Dict, Any, List
import contextlib
import threading
import time
import atexit
import socket
import subprocess
import sys
import pickle
from leann.embedding_server_manager import EmbeddingServerManager
from leann.searcher_base import BaseSearcher
from leann.registry import register_backend
from leann.interface import (
LeannBackendFactoryInterface,
LeannBackendBuilderInterface,
LeannBackendSearcherInterface
)
def _get_diskann_metrics():
from . import _diskannpy as diskannpy
return {
@@ -52,211 +47,87 @@ class DiskannBackend(LeannBackendFactoryInterface):
@staticmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
with open(meta_path, 'r') as f:
meta = json.load(f)
# Pass essential metadata to the searcher
kwargs['meta'] = meta
return DiskannSearcher(index_path, **kwargs)
class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
data = data.astype(np.float32)
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename)
# Create label map: integer -> string_id
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f:
pickle.dump(label_map, f)
build_kwargs = {**self.build_params, **kwargs}
metric_str = build_kwargs.get("distance_metric", "mips").lower()
METRIC_MAP = _get_diskann_metrics()
metric_enum = METRIC_MAP.get(metric_str)
metric_enum = _get_diskann_metrics().get(build_kwargs.get("distance_metric", "mips").lower())
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
raise ValueError(f"Unsupported distance_metric.")
complexity = build_kwargs.get("complexity", 64)
graph_degree = build_kwargs.get("graph_degree", 32)
final_index_ram_limit = build_kwargs.get("search_memory_maximum", 4.0)
indexing_ram_budget = build_kwargs.get("build_memory_maximum", 8.0)
num_threads = build_kwargs.get("num_threads", 8)
pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
codebook_prefix = ""
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
try:
from . import _diskannpy as diskannpy
with chdir(index_dir):
diskannpy.build_disk_float_index(
metric_enum,
data_filename,
index_prefix,
complexity,
graph_degree,
final_index_ram_limit,
indexing_ram_budget,
num_threads,
pq_disk_bytes,
codebook_prefix
metric_enum, data_filename, index_prefix,
build_kwargs.get("complexity", 64), build_kwargs.get("graph_degree", 32),
build_kwargs.get("search_memory_maximum", 4.0), build_kwargs.get("build_memory_maximum", 8.0),
build_kwargs.get("num_threads", 8), build_kwargs.get("pq_disk_bytes", 0), ""
)
print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
except Exception as e:
print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
raise
finally:
temp_data_file = index_dir / data_filename
if temp_data_file.exists():
os.remove(temp_data_file)
class DiskannSearcher(LeannBackendSearcherInterface):
class DiskannSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs):
self.meta = kwargs.get("meta", {})
if not self.meta:
raise ValueError("DiskannSearcher requires metadata from .meta.json.")
super().__init__(index_path, backend_module_name="leann_backend_diskann.embedding_server", **kwargs)
from . import _diskannpy as diskannpy
self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
self.index_path = Path(index_path)
self.index_dir = self.index_path.parent
self.index_prefix = self.index_path.stem
# Load the label map
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
self.label_map = pickle.load(f)
# Extract parameters for DiskANN
distance_metric = kwargs.get("distance_metric", "mips").lower()
METRIC_MAP = _get_diskann_metrics()
metric_enum = METRIC_MAP.get(distance_metric)
metric_enum = _get_diskann_metrics().get(distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
num_threads = kwargs.get("num_threads", 8)
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
self.num_threads = kwargs.get("num_threads", 8)
self.zmq_port = kwargs.get("zmq_port", 6666)
try:
from . import _diskannpy as diskannpy
full_index_prefix = str(self.index_dir / self.index_prefix)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
)
self.num_threads = num_threads
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_diskann.embedding_server"
)
print("✅ DiskANN index loaded successfully.")
except Exception as e:
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
raise
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, self.num_threads,
kwargs.get("num_nodes_to_cache", 0), 1, self.zmq_port, "", ""
)
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
complexity = kwargs.get("complexity", 256)
beam_width = kwargs.get("beam_width", 4)
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)
skip_search_reorder = kwargs.get("skip_search_reorder", False)
recompute_beighbor_embeddings = kwargs.get("recompute_beighbor_embeddings", False)
dedup_node_dis = kwargs.get("dedup_node_dis", False)
prune_ratio = kwargs.get("prune_ratio", 0.0)
batch_recompute = kwargs.get("batch_recompute", False)
global_pruning = kwargs.get("global_pruning", False)
port = kwargs.get("zmq_port", self.zmq_port)
if recompute_beighbor_embeddings:
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
if not self.embedding_model:
raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
recompute = kwargs.get("recompute_beighbor_embeddings", False)
if recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists():
raise RuntimeError(f"FATAL: Recompute mode enabled but metadata file not found: {meta_file_path}")
zmq_port = kwargs.get("zmq_port", self.zmq_port)
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
passages_file = kwargs.get("passages_file")
if not passages_file:
# Pass the metadata file instead of a single passage file
meta_file_path = self.index_path.parent / f"{self.index_path.name}.meta.json"
if meta_file_path.exists():
passages_file = str(meta_file_path)
print(f"INFO: Using metadata file for lazy loading: {passages_file}")
else:
raise RuntimeError(f"FATAL: Recompute mode enabled but metadata file not found: {meta_file_path}")
server_started = self.embedding_server_manager.start_server(
port=self.zmq_port,
model_name=self.embedding_model,
distance_metric=kwargs.get("distance_metric", "mips"),
passages_file=passages_file
)
if not server_started:
raise RuntimeError(f"Failed to start DiskANN embedding server on port {self.zmq_port}")
if query.dtype != np.float32:
query = query.astype(np.float32)
if query.ndim == 1:
query = np.expand_dims(query, axis=0)
try:
labels, distances = self._index.batch_search(
query,
query.shape[0],
top_k,
complexity,
beam_width,
self.num_threads,
USE_DEFERRED_FETCH,
skip_search_reorder,
recompute_beighbor_embeddings,
dedup_node_dis,
prune_ratio,
batch_recompute,
global_pruning
)
# Convert integer labels to string IDs
string_labels = []
for batch_labels in labels:
batch_string_labels = []
for int_label in batch_labels:
if int_label in self.label_map:
batch_string_labels.append(self.label_map[int_label])
else:
batch_string_labels.append(f"unknown_{int_label}")
string_labels.append(batch_string_labels)
return {"labels": string_labels, "distances": distances}
except Exception as e:
print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
batch_size = query.shape[0]
return {"labels": [[f"error_{i}" for i in range(top_k)] for _ in range(batch_size)],
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
def __del__(self):
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()
labels, distances = self._index.batch_search(
query, query.shape[0], top_k,
kwargs.get("complexity", 256), kwargs.get("beam_width", 4), self.num_threads,
kwargs.get("USE_DEFERRED_FETCH", False), kwargs.get("skip_search_reorder", False),
recompute, kwargs.get("dedup_node_dis", False), kwargs.get("prune_ratio", 0.0),
kwargs.get("batch_recompute", False), kwargs.get("global_pruning", False)
)
string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels]
return {"labels": string_labels, "distances": distances}

View File

@@ -3,16 +3,9 @@ import os
import json
from pathlib import Path
from typing import Dict, Any, List
import contextlib
import threading
import time
import atexit
import socket
import subprocess
import sys
import pickle
from leann.embedding_server_manager import EmbeddingServerManager
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
from leann.registry import register_backend
@@ -38,306 +31,120 @@ class HNSWBackend(LeannBackendFactoryInterface):
@staticmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
with open(meta_path, 'r') as f:
meta = json.load(f)
kwargs['meta'] = meta
return HNSWSearcher(index_path, **kwargs)
class HNSWBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs.copy()
# --- Configuration defaults with standardized names ---
self.is_compact = self.build_params.setdefault("is_compact", True)
self.is_recompute = self.build_params.setdefault("is_recompute", True)
# --- Additional Options ---
self.is_skip_neighbors = self.build_params.setdefault("is_skip_neighbors", False)
self.disk_cache_ratio = self.build_params.setdefault("disk_cache_ratio", 0.0)
self.external_storage_path = self.build_params.get("external_storage_path", None)
# --- Standard HNSW parameters ---
self.M = self.build_params.setdefault("M", 32)
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions")
if self.is_skip_neighbors and not self.is_compact:
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
if self.is_recompute and not self.is_compact:
raise ValueError("is_recompute requires is_compact=True for efficiency")
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
"""Build HNSW index using FAISS"""
from . import faiss
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
data = data.astype(np.float32)
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
# Create label map: integer -> string_id
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f:
pickle.dump(label_map, f)
metric_str = self.distance_metric.lower()
metric_enum = get_metric_map().get(metric_str)
metric_enum = get_metric_map().get(self.distance_metric.lower())
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
M = self.M
efConstruction = self.efConstruction
dim = self.dimensions
if not dim:
dim = data.shape[1]
dim = self.dimensions or data.shape[1]
index = faiss.IndexHNSWFlat(dim, self.M, metric_enum)
index.hnsw.efConstruction = self.efConstruction
print(f"INFO: Building HNSW index for {data.shape[0]} vectors with metric {metric_enum}...")
try:
index = faiss.IndexHNSWFlat(dim, M, metric_enum)
index.hnsw.efConstruction = efConstruction
if metric_str == "cosine":
faiss.normalize_L2(data)
index.add(data.shape[0], faiss.swig_ptr(data))
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))
print(f"✅ HNSW index built successfully at '{index_file}'")
if self.distance_metric.lower() == "cosine":
faiss.normalize_L2(data)
if self.is_compact:
self._convert_to_csr(index_file)
except Exception as e:
print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
raise
index.add(data.shape[0], faiss.swig_ptr(data))
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))
if self.is_compact:
self._convert_to_csr(index_file)
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
try:
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
print(f"INFO: Converting HNSW index to {mode_str} format...")
csr_temp_file = index_file.with_suffix(".csr.tmp")
success = convert_hnsw_graph_to_csr(
str(index_file),
str(csr_temp_file),
prune_embeddings=self.is_recompute
)
if success:
print("✅ CSR conversion successful.")
import shutil
shutil.move(str(csr_temp_file), str(index_file))
print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
else:
# Clean up and fail fast
if csr_temp_file.exists():
os.remove(csr_temp_file)
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
except Exception as e:
print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
raise
class HNSWSearcher(LeannBackendSearcherInterface):
def _get_index_storage_status_from_meta(self) -> tuple[bool, bool]:
"""
Get storage status from metadata with sensible defaults.
Returns:
A tuple (is_compact, is_pruned).
"""
# Check if metadata has these flags
is_compact = self.meta.get('is_compact', True) # Default to compact (CSR format)
is_pruned = self.meta.get('is_pruned', True) # Default to pruned (embeddings removed)
print(f"INFO: Storage status from metadata: is_compact={is_compact}, is_pruned={is_pruned}")
return is_compact, is_pruned
csr_temp_file = index_file.with_suffix(".csr.tmp")
success = convert_hnsw_graph_to_csr(
str(index_file), str(csr_temp_file), prune_embeddings=self.is_recompute
)
if success:
import shutil
shutil.move(str(csr_temp_file), str(index_file))
else:
if csr_temp_file.exists():
os.remove(csr_temp_file)
raise RuntimeError("CSR conversion failed")
class HNSWSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs):
super().__init__(index_path, backend_module_name="leann_backend_hnsw.hnsw_embedding_server", **kwargs)
from . import faiss
self.meta = kwargs.get("meta", {})
if not self.meta:
raise ValueError("HNSWSearcher requires metadata from .meta.json.")
self.dimensions = self.meta.get("dimensions")
if not self.dimensions:
raise ValueError("Dimensions not found in Leann metadata.")
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
metric_enum = get_metric_map().get(self.distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
self.is_compact, self.is_pruned = self._get_index_storage_status_from_meta()
# Check for embedding model override (not allowed)
if 'embedding_model' in kwargs and kwargs['embedding_model'] != self.embedding_model:
raise ValueError(f"Embedding model override not allowed. Index uses '{self.embedding_model}', but got '{kwargs['embedding_model']}'")
path = Path(index_path)
self.index_dir = path.parent
self.index_prefix = path.stem
# Load the label map
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
self.label_map = pickle.load(f)
index_file = self.index_dir / f"{self.index_prefix}.index"
index_file = self.index_dir / f"{self.index_path.stem}.index"
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
# Get storage status from metadata with user overrides
self.is_compact, self.is_pruned = self._get_index_storage_status_from_meta()
# Allow override of storage parameters via kwargs
if 'is_compact' in kwargs:
self.is_compact = kwargs['is_compact']
if 'is_pruned' in kwargs:
self.is_pruned = kwargs['is_pruned']
# Validate configuration constraints
if not self.is_compact and kwargs.get("is_skip_neighbors", False):
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
if kwargs.get("is_recompute", False) and kwargs.get("external_storage_path"):
raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously")
hnsw_config = faiss.HNSWIndexConfig()
hnsw_config.is_compact = self.is_compact
# Apply additional configuration options with strict validation
hnsw_config.is_skip_neighbors = kwargs.get("is_skip_neighbors", False)
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
hnsw_config.external_storage_path = kwargs.get("external_storage_path")
self.zmq_port = kwargs.get("zmq_port", 5557)
if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
print(f"INFO: Loading index with is_compact={self.is_compact}, is_pruned={self.is_pruned}")
print(f"INFO: Config - skip_neighbors={hnsw_config.is_skip_neighbors}, recompute={hnsw_config.is_recompute}")
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
if self.is_compact:
print("✅ Compact CSR format HNSW index loaded successfully.")
else:
print("✅ Standard HNSW index loaded successfully.")
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned but recompute is disabled.")
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
def _get_index_storage_status_from_meta(self) -> tuple[bool, bool]:
is_compact = self.meta.get('is_compact', True)
is_pruned = self.meta.get('is_pruned', True)
return is_compact, is_pruned
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
"""Search using HNSW index with optional recompute functionality"""
from . import faiss
ef = kwargs.get("ef", 128)
if self.is_pruned:
print(f"INFO: Index is pruned - ensuring embedding server is running for recompute.")
if not self.embedding_model:
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
passages_file = kwargs.get("passages_file")
if not passages_file:
# Pass the metadata file instead of a single passage file
meta_file_path = self.index_dir / f"{self.index_prefix}.index.meta.json"
if meta_file_path.exists():
passages_file = str(meta_file_path)
print(f"INFO: Using metadata file for lazy loading: {passages_file}")
else:
raise RuntimeError(f"FATAL: Index is pruned but metadata file not found: {meta_file_path}")
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists():
raise RuntimeError(f"FATAL: Index is pruned but metadata file not found: {meta_file_path}")
zmq_port = kwargs.get("zmq_port", 5557)
server_started = self.embedding_server_manager.start_server(
port=zmq_port,
model_name=self.embedding_model,
passages_file=passages_file,
distance_metric=self.distance_metric
)
if not server_started:
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
if query.dtype != np.float32:
query = query.astype(np.float32)
if query.ndim == 1:
query = np.expand_dims(query, axis=0)
if self.distance_metric == "cosine":
faiss.normalize_L2(query)
try:
self._index.hnsw.efSearch = ef
params = faiss.SearchParametersHNSW()
params.zmq_port = kwargs.get("zmq_port", self.zmq_port)
params.efSearch = ef
params.beam_size = 2 # Match research system beam_size
batch_size = query.shape[0]
distances = np.empty((batch_size, top_k), dtype=np.float32)
labels = np.empty((batch_size, top_k), dtype=np.int64)
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
# 🐛 DEBUG: Print raw faiss results before conversion
print(f"🔍 DEBUG HNSW Search Results:")
print(f" Query shape: {query.shape}")
print(f" Top_k: {top_k}")
print(f" Raw faiss indices: {labels[0] if len(labels) > 0 else 'No results'}")
print(f" Raw faiss distances: {distances[0] if len(distances) > 0 else 'No results'}")
# Convert integer labels to string IDs
string_labels = []
for batch_idx, batch_labels in enumerate(labels):
batch_string_labels = []
print(f" Batch {batch_idx} conversion:")
for i, int_label in enumerate(batch_labels):
if int_label in self.label_map:
string_id = self.label_map[int_label]
batch_string_labels.append(string_id)
print(f" faiss[{int_label}] -> passage_id '{string_id}' (distance: {distances[batch_idx][i]:.4f})")
else:
unknown_id = f"unknown_{int_label}"
batch_string_labels.append(unknown_id)
print(f" faiss[{int_label}] -> {unknown_id} (NOT FOUND in label_map!)")
string_labels.append(batch_string_labels)
return {"labels": string_labels, "distances": distances}
except Exception as e:
print(f"💥 ERROR: HNSW search failed. Exception: {e}")
raise
def __del__(self):
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()
params = faiss.SearchParametersHNSW()
params.zmq_port = kwargs.get("zmq_port", 5557)
params.efSearch = kwargs.get("ef", 128)
params.beam_size = 2
batch_size = query.shape[0]
distances = np.empty((batch_size, top_k), dtype=np.float32)
labels = np.empty((batch_size, top_k), dtype=np.int64)
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels]
return {"labels": string_labels, "distances": distances}

View File

@@ -0,0 +1,136 @@
#!/usr/bin/env python3
"""
This file contains the chat generation logic for the LEANN project,
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class LLMInterface(ABC):
"""Abstract base class for a generic Language Model (LLM) interface."""
@abstractmethod
def ask(self, prompt: str, **kwargs) -> str:
"""
Sends a prompt to the LLM and returns the generated text.
Args:
prompt: The input prompt for the LLM.
**kwargs: Additional keyword arguments for the LLM backend.
Returns:
The response string from the LLM.
"""
pass
class OllamaChat(LLMInterface):
"""LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
self.model = model
self.host = host
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
try:
import requests
# Check if the Ollama server is responsive
if host:
requests.get(host)
except ImportError:
raise ImportError("The 'requests' library is required for Ollama. Please install it with 'pip install requests'.")
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
raise ConnectionError(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
def ask(self, prompt: str, **kwargs) -> str:
import requests
import json
full_url = f"{self.host}/api/generate"
payload = {
"model": self.model,
"prompt": prompt,
"stream": False, # Keep it simple for now
"options": kwargs
}
logger.info(f"Sending request to Ollama: {payload}")
try:
response = requests.post(full_url, data=json.dumps(payload))
response.raise_for_status()
# The response from Ollama can be a stream of JSON objects, handle this
response_parts = response.text.strip().split('\n')
full_response = ""
for part in response_parts:
if part:
json_part = json.loads(part)
full_response += json_part.get("response", "")
if json_part.get("done"):
break
return full_response
except requests.exceptions.RequestException as e:
logger.error(f"Error communicating with Ollama: {e}")
return f"Error: Could not get a response from Ollama. Details: {e}"
class HFChat(LLMInterface):
"""LLM interface for local Hugging Face Transformers models."""
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'")
try:
from transformers import pipeline
except ImportError:
raise ImportError("The 'transformers' library is required for Hugging Face models. Please install it with 'pip install transformers'.")
self.pipeline = pipeline("text-generation", model=model_name)
def ask(self, prompt: str, **kwargs) -> str:
# Sensible defaults for text generation
params = {
"max_length": 500,
"num_return_sequences": 1,
**kwargs
}
logger.info(f"Generating text with Hugging Face model with params: {params}")
results = self.pipeline(prompt, **params)
return results[0]['generated_text']
class SimulatedChat(LLMInterface):
"""A simple simulated chat for testing and development."""
def ask(self, prompt: str, **kwargs) -> str:
logger.info("Simulating LLM call...")
print("Prompt sent to LLM (simulation):", prompt[:500] + "...")
return "This is a simulated answer from the LLM based on the retrieved context."
def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
"""
Factory function to get an LLM interface based on configuration.
Args:
llm_config: A dictionary specifying the LLM type and its parameters.
Example: {"type": "ollama", "model": "llama3"}
{"type": "hf", "model": "distilgpt2"}
None (for simulation mode)
Returns:
An instance of an LLMInterface subclass.
"""
if llm_config is None:
logger.info("No LLM config provided, defaulting to simulated chat.")
return SimulatedChat()
llm_type = llm_config.get("type", "simulated")
model = llm_config.get("model")
logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")
if llm_type == "ollama":
return OllamaChat(model=model, host=llm_config.get("host"))
elif llm_type == "hf":
return HFChat(model_name=model)
elif llm_type == "simulated":
return SimulatedChat()
else:
raise ValueError(f"Unknown LLM type: '{llm_type}'")

View File

@@ -0,0 +1,97 @@
import json
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Any, List
import numpy as np
from .embedding_server_manager import EmbeddingServerManager
from .interface import LeannBackendSearcherInterface
class BaseSearcher(LeannBackendSearcherInterface, ABC):
"""
Abstract base class for Leann searchers, containing common logic for
loading metadata, managing embedding servers, and handling file paths.
"""
def __init__(self, index_path: str, backend_module_name: str, **kwargs):
"""
Initializes the BaseSearcher.
Args:
index_path: Path to the Leann index file (e.g., '.../my_index.leann').
backend_module_name: The specific embedding server module to use
(e.g., 'leann_backend_hnsw.hnsw_embedding_server').
**kwargs: Additional keyword arguments.
"""
self.index_path = Path(index_path)
self.index_dir = self.index_path.parent
self.meta = kwargs.get("meta", self._load_meta())
if not self.meta:
raise ValueError("Searcher requires metadata from .meta.json.")
self.dimensions = self.meta.get("dimensions")
if not self.dimensions:
raise ValueError("Dimensions not found in Leann metadata.")
self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
self.label_map = self._load_label_map()
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name
)
def _load_meta(self) -> Dict[str, Any]:
"""Loads the metadata file associated with the index."""
# This is the corrected logic for finding the meta file.
meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
with open(meta_path, 'r', encoding='utf-8') as f:
return json.load(f)
def _load_label_map(self) -> Dict[int, str]:
"""Loads the mapping from integer IDs to string IDs."""
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
return pickle.load(f)
def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> None:
"""
Ensures the embedding server is running if recompute is needed.
This is a helper for subclasses.
"""
if not self.embedding_model:
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
server_started = self.embedding_server_manager.start_server(
port=port,
model_name=self.embedding_model,
passages_file=passages_source_file,
distance_metric=kwargs.get("distance_metric"),
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {kwargs.get('zmq_port')}")
@abstractmethod
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
"""
Search for the top_k nearest neighbors of the query vector.
Must be implemented by subclasses.
"""
pass
def __del__(self):
"""Ensures the embedding server is stopped when the searcher is destroyed."""
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()
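For orientation, a minimal subclass sketch of the contract BaseSearcher expects from a backend (the module name and the ANN call are placeholders, not real backends):
class MyBackendSearcher(BaseSearcher):
    def __init__(self, index_path: str, **kwargs):
        super().__init__(index_path, backend_module_name="my_backend.embedding_server", **kwargs)
        # load the backend-specific index structure here
    def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
        if kwargs.get("recompute_beighbor_embeddings", False):
            meta_file = self.index_dir / f"{self.index_path.name}.meta.json"
            self._ensure_server_running(str(meta_file), port=kwargs.get("zmq_port", 5557), **kwargs)
        # placeholder ANN result standing in for the backend-specific search call
        int_labels, distances = [[0]], np.zeros((1, top_k), dtype=np.float32)
        labels = [[self.label_map.get(i, f"unknown_{i}") for i in batch] for batch in int_labels]
        return {"labels": labels, "distances": distances}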