fix readme

This commit is contained in:
yichuan-w
2025-10-08 21:38:55 +00:00
parent 3ec5e8d035
commit 5be0c144ad
72 changed files with 16608 additions and 4175 deletions

View File

@@ -343,7 +343,8 @@ class DiskannSearcher(BaseSearcher):
"full_index_prefix": full_index_prefix,
"num_threads": self.num_threads,
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
"cache_mechanism": 1,
# 1 -> initialize cache using sample_data; 2 -> ready cache without init; others disable cache
"cache_mechanism": kwargs.get("cache_mechanism", 1),
"pq_prefix": "",
"partition_prefix": partition_prefix,
}

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Optional
from typing import Any, Optional
import numpy as np
import zmq
@@ -32,6 +32,16 @@ if not logger.handlers:
logger.propagate = False
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
try:
PROVIDER_OPTIONS: dict[str, Any] = (
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
)
except json.JSONDecodeError:
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
PROVIDER_OPTIONS = {}
def create_diskann_embedding_server(
passages_file: Optional[str] = None,
zmq_port: int = 5555,
@@ -181,7 +191,12 @@ def create_diskann_embedding_server(
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
# Process embeddings using unified computation
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
@@ -296,7 +311,12 @@ def create_diskann_embedding_server(
continue
# Process the request
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(f"Computed embeddings shape: {embeddings.shape}")
# Validation

View File

@@ -5,6 +5,8 @@ import os
import struct
import sys
import time
from dataclasses import dataclass
from typing import Any, Optional
import numpy as np
@@ -237,6 +239,288 @@ def write_compact_format(
f_out.write(storage_data)
@dataclass
class HNSWComponents:
original_hnsw_data: dict[str, Any]
assign_probas_np: np.ndarray
cum_nneighbor_per_level_np: np.ndarray
levels_np: np.ndarray
is_compact: bool
compact_level_ptr: Optional[np.ndarray] = None
compact_node_offsets_np: Optional[np.ndarray] = None
compact_neighbors_data: Optional[list[int]] = None
offsets_np: Optional[np.ndarray] = None
neighbors_np: Optional[np.ndarray] = None
storage_fourcc: int = NULL_INDEX_FOURCC
storage_data: bytes = b""
def _read_hnsw_structure(f) -> HNSWComponents:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
raise ValueError(
f"Unexpected HNSW FourCC: {hnsw_index_fourcc:08x}. Expected one of {EXPECTED_HNSW_FOURCCS}."
)
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f, "<i")
original_hnsw_data["ntotal"] = read_struct(f, "<q")
original_hnsw_data["dummy1"] = read_struct(f, "<q")
original_hnsw_data["dummy2"] = read_struct(f, "<q")
original_hnsw_data["is_trained"] = read_struct(f, "?")
original_hnsw_data["metric_type"] = read_struct(f, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f, "<f")
assign_probas_np = read_numpy_vector(f, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f, np.int32, "i")
levels_np = read_numpy_vector(f, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = read_struct(f, "<I")
compact_neighbors_data_np = read_numpy_vector(f, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
storage_data = f.read()
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=True,
compact_level_ptr=compact_level_ptr,
compact_node_offsets_np=compact_node_offsets_np,
compact_neighbors_data=compact_neighbors_data,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
# Non-compact case
f.seek(pos_before_compact)
pos_before_probe = f.tell()
try:
suspected_flag = read_struct(f, "<B")
if suspected_flag != 0x00:
f.seek(pos_before_probe)
except EOFError:
f.seek(pos_before_probe)
offsets_np = read_numpy_vector(f, np.uint64, "Q")
neighbors_np = read_numpy_vector(f, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = NULL_INDEX_FOURCC
storage_data = b""
try:
storage_fourcc = read_struct(f, "<I")
storage_data = f.read()
except EOFError:
storage_fourcc = NULL_INDEX_FOURCC
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=False,
offsets_np=offsets_np,
neighbors_np=neighbors_np,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
def _read_hnsw_structure_from_file(path: str) -> HNSWComponents:
with open(path, "rb") as f:
return _read_hnsw_structure(f)
def write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
storage_fourcc,
storage_data,
):
"""Write non-compact HNSW data in original FAISS order."""
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
if original_hnsw_data["metric_type"] > 1:
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
write_numpy_vector(f_out, assign_probas_np, "d")
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
write_numpy_vector(f_out, levels_np, "i")
write_numpy_vector(f_out, offsets_np, "Q")
write_numpy_vector(f_out, neighbors_np, "i")
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
f_out.write(struct.pack("<I", storage_fourcc))
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
f_out.write(storage_data)
def prune_hnsw_embeddings(input_filename: str, output_filename: str) -> bool:
"""Rewrite an HNSW index while dropping the embedded storage section."""
start_time = time.time()
try:
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f_in, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
print(
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
file=sys.stderr,
)
return False
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f_in, "<i")
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
levels_np = read_numpy_vector(f_in, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f_in.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f_in, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = read_struct(f_in, "<I")
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
_storage_data = f_in.read()
write_compact_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
NULL_INDEX_FOURCC,
b"",
)
else:
f_in.seek(pos_before_compact)
pos_before_probe = f_in.tell()
try:
suspected_flag = read_struct(f_in, "<B")
if suspected_flag != 0x00:
f_in.seek(pos_before_probe)
except EOFError:
f_in.seek(pos_before_probe)
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = None
_storage_data = b""
try:
_storage_fourcc = read_struct(f_in, "<I")
_storage_data = f_in.read()
except EOFError:
_storage_fourcc = NULL_INDEX_FOURCC
write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
NULL_INDEX_FOURCC,
b"",
)
print(f"[{time.time() - start_time:.2f}s] Pruned embeddings from {input_filename}")
return True
except Exception as exc:
print(f"Failed to prune embeddings: {exc}", file=sys.stderr)
return False
# --- Main Conversion Logic ---
@@ -700,6 +984,29 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
pass
def prune_hnsw_embeddings_inplace(index_filename: str) -> bool:
"""Convenience wrapper to prune embeddings in-place."""
temp_path = f"{index_filename}.prune.tmp"
success = prune_hnsw_embeddings(index_filename, temp_path)
if success:
try:
os.replace(temp_path, index_filename)
except Exception as exc: # pragma: no cover - defensive
logger.error(f"Failed to replace original index with pruned version: {exc}")
try:
os.remove(temp_path)
except OSError:
pass
return False
else:
try:
os.remove(temp_path)
except OSError:
pass
return success
# --- Script Execution ---
if __name__ == "__main__":
parser = argparse.ArgumentParser(

View File

@@ -14,7 +14,7 @@ from leann.interface import (
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
logger = logging.getLogger(__name__)
@@ -90,8 +90,19 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))
# Persist ID map so searcher can map FAISS integer labels back to passage IDs
try:
idmap_file = index_dir / f"{index_prefix}.ids.txt"
with open(idmap_file, "w", encoding="utf-8") as f:
for id_str in ids:
f.write(str(id_str) + "\n")
except Exception as e:
logger.warning(f"Failed to write ID map: {e}")
if self.is_compact:
self._convert_to_csr(index_file)
elif self.is_recompute:
prune_hnsw_embeddings_inplace(str(index_file))
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
@@ -133,10 +144,10 @@ class HNSWSearcher(BaseSearcher):
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
self.is_compact, self.is_pruned = (
self.meta.get("is_compact", True),
self.meta.get("is_pruned", True),
)
backend_meta_kwargs = self.meta.get("backend_kwargs", {})
self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))
index_file = self.index_dir / f"{self.index_path.stem}.index"
if not index_file.exists():
@@ -150,6 +161,16 @@ class HNSWSearcher(BaseSearcher):
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
# Load ID map if available
self._id_map: list[str] = []
try:
idmap_file = self.index_dir / f"{self.index_path.stem}.ids.txt"
if idmap_file.exists():
with open(idmap_file, encoding="utf-8") as f:
self._id_map = [line.rstrip("\n") for line in f]
except Exception as e:
logger.warning(f"Failed to load ID map: {e}")
def search(
self,
query: np.ndarray,
@@ -248,6 +269,19 @@ class HNSWSearcher(BaseSearcher):
)
search_time = time.time() - search_time
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
if self._id_map:
def map_label(x: int) -> str:
if 0 <= x < len(self._id_map):
return self._id_map[x]
return str(x)
string_labels = [
[map_label(int(label)) for label in batch_labels] for batch_labels in labels
]
else:
string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Optional
from typing import Any, Optional
import msgpack
import numpy as np
@@ -24,13 +24,35 @@ logger = logging.getLogger(__name__)
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Ensure we have a handler if none exists
# Ensure we have handlers if none exist
if not logger.handlers:
handler = logging.StreamHandler()
stream_handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
log_path = os.getenv("LEANN_HNSW_LOG_PATH")
if log_path:
try:
file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
file_formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - [pid=%(process)d] %(message)s"
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
except Exception as exc: # pragma: no cover - best effort logging
logger.warning(f"Failed to attach file handler for log path {log_path}: {exc}")
logger.propagate = False
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
try:
PROVIDER_OPTIONS: dict[str, Any] = (
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
)
except json.JSONDecodeError:
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
PROVIDER_OPTIONS = {}
def create_hnsw_embedding_server(
@@ -92,6 +114,35 @@ def create_hnsw_embedding_server(
embedding_dim = 0
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
# Attempt to load ID map (maps FAISS integer labels -> passage IDs)
id_map: list[str] = []
try:
meta_path = Path(passages_file)
base = meta_path.name
if base.endswith(".meta.json"):
base = base[: -len(".meta.json")] # e.g., laion_index.leann
if base.endswith(".leann"):
base = base[: -len(".leann")] # e.g., laion_index
idmap_file = meta_path.parent / f"{base}.ids.txt"
if idmap_file.exists():
with open(idmap_file, encoding="utf-8") as f:
id_map = [line.rstrip("\n") for line in f]
logger.info(f"Loaded ID map with {len(id_map)} entries from {idmap_file}")
else:
logger.warning(f"ID map file not found at {idmap_file}; will use raw labels")
except Exception as e:
logger.warning(f"Failed to load ID map: {e}")
def _map_node_id(nid) -> str:
try:
if id_map is not None and len(id_map) > 0 and isinstance(nid, (int, np.integer)):
idx = int(nid)
if 0 <= idx < len(id_map):
return id_map[idx]
except Exception:
pass
return str(nid)
# (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event):
@@ -138,7 +189,12 @@ def create_hnsw_embedding_server(
):
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
@@ -168,13 +224,14 @@ def create_hnsw_embedding_server(
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e:
@@ -187,7 +244,10 @@ def create_hnsw_embedding_server(
if texts:
try:
embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
@@ -238,13 +298,14 @@ def create_hnsw_embedding_server(
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as e:
@@ -252,7 +313,12 @@ def create_hnsw_embedding_server(
if texts:
try:
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)

View File

@@ -18,14 +18,16 @@ dependencies = [
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
"torch>=2.0.0",
"sentence-transformers>=2.2.0",
"sentence-transformers>=3.0.0",
"llama-index-core>=0.12.0",
"llama-index-readers-file>=0.4.0", # Essential for document reading
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
"python-dotenv>=1.0.0",
"openai>=1.0.0",
"huggingface-hub>=0.20.0",
"transformers>=4.30.0",
# Keep transformers below 4.46: 4.46.0 adds Python 3.10-only return type syntax and
# breaks Python 3.9 environments.
"transformers>=4.30.0,<4.46",
"requests>=2.25.0",
"accelerate>=0.20.0",
"PyPDF2>=3.0.0",
@@ -40,7 +42,7 @@ dependencies = [
[project.optional-dependencies]
colab = [
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
"transformers>=4.30.0,<5.0.0", # Limit transformers version
"transformers>=4.30.0,<4.46", # 4.46.0 switches to PEP 604 typing (int | None), breaks Py3.9
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
]

View File

@@ -5,6 +5,7 @@ with the correct, original embedding logic from the user's reference code.
import json
import logging
import os
import pickle
import re
import subprocess
@@ -15,10 +16,13 @@ from pathlib import Path
from typing import Any, Literal, Optional, Union
import numpy as np
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
from leann.interactive_utils import create_api_session
from leann.interface import LeannBackendSearcherInterface
from .chat import get_llm
from .embedding_server_manager import EmbeddingServerManager
from .interface import LeannBackendFactoryInterface
from .metadata_filter import MetadataFilterEngine
from .registry import BACKEND_REGISTRY
@@ -38,6 +42,7 @@ def compute_embeddings(
use_server: bool = True,
port: Optional[int] = None,
is_build=False,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
"""
Computes embeddings using different backends.
@@ -71,6 +76,7 @@ def compute_embeddings(
model_name,
mode=mode,
is_build=is_build,
provider_options=provider_options,
)
@@ -277,6 +283,7 @@ class LeannBuilder:
embedding_model: str = "facebook/contriever",
dimensions: Optional[int] = None,
embedding_mode: str = "sentence-transformers",
embedding_options: Optional[dict[str, Any]] = None,
**backend_kwargs,
):
self.backend_name = backend_name
@@ -299,6 +306,7 @@ class LeannBuilder:
self.embedding_model = embedding_model
self.dimensions = dimensions
self.embedding_mode = embedding_mode
self.embedding_options = embedding_options or {}
# Check if we need to use cosine distance for normalized embeddings
normalized_embeddings_models = {
@@ -406,6 +414,7 @@ class LeannBuilder:
self.embedding_model,
self.embedding_mode,
use_server=False,
provider_options=self.embedding_options,
)[0]
)
path = Path(index_path)
@@ -445,8 +454,20 @@ class LeannBuilder:
self.embedding_mode,
use_server=False,
is_build=True,
provider_options=self.embedding_options,
)
string_ids = [chunk["id"] for chunk in self.chunks]
# Persist ID map alongside index so backends that return integer labels can remap to passage IDs
try:
idmap_file = (
index_dir
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
)
with open(idmap_file, "w", encoding="utf-8") as f:
for sid in string_ids:
f.write(str(sid) + "\n")
except Exception:
pass
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
@@ -471,14 +492,15 @@ class LeannBuilder:
],
}
if self.embedding_options:
meta_data["embedding_options"] = self.embedding_options
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = (
is_compact and is_recompute
) # Pruned only if compact and recompute
meta_data["is_pruned"] = bool(is_recompute)
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
@@ -565,6 +587,17 @@ class LeannBuilder:
# Build the vector index using precomputed embeddings
string_ids = [str(id_val) for id_val in ids]
# Persist ID map (order == embeddings order)
try:
idmap_file = (
index_dir
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
)
with open(idmap_file, "w", encoding="utf-8") as f:
for sid in string_ids:
f.write(str(sid) + "\n")
except Exception:
pass
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path)
@@ -593,18 +626,242 @@ class LeannBuilder:
"embeddings_source": str(embeddings_file),
}
if self.embedding_options:
meta_data["embedding_options"] = self.embedding_options
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = is_compact and is_recompute
meta_data["is_pruned"] = bool(is_recompute)
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
def update_index(self, index_path: str):
"""Append new passages and vectors to an existing HNSW index."""
if not self.chunks:
raise ValueError("No new chunks provided for update.")
path = Path(index_path)
index_dir = path.parent
index_name = path.name
index_prefix = path.stem
meta_path = index_dir / f"{index_name}.meta.json"
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
index_file = index_dir / f"{index_prefix}.index"
if not meta_path.exists() or not passages_file.exists() or not offset_file.exists():
raise FileNotFoundError("Index metadata or passage files are missing; cannot update.")
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found: {index_file}")
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
backend_name = meta.get("backend_name")
if backend_name != self.backend_name:
raise ValueError(
f"Index was built with backend '{backend_name}', cannot update with '{self.backend_name}'."
)
meta_backend_kwargs = meta.get("backend_kwargs", {})
index_is_compact = meta.get("is_compact", meta_backend_kwargs.get("is_compact", True))
if index_is_compact:
raise ValueError(
"Compact HNSW indices do not support in-place updates. Rebuild required."
)
distance_metric = meta_backend_kwargs.get(
"distance_metric", self.backend_kwargs.get("distance_metric", "mips")
).lower()
needs_recompute = bool(
meta.get("is_pruned")
or meta_backend_kwargs.get("is_recompute")
or self.backend_kwargs.get("is_recompute")
)
with open(offset_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
existing_ids = set(offset_map.keys())
valid_chunks: list[dict[str, Any]] = []
for chunk in self.chunks:
text = chunk.get("text", "")
if not isinstance(text, str) or not text.strip():
continue
metadata = chunk.setdefault("metadata", {})
passage_id = chunk.get("id") or metadata.get("id")
if passage_id and passage_id in existing_ids:
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
valid_chunks.append(chunk)
if not valid_chunks:
raise ValueError("No valid chunks to append.")
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
embeddings = compute_embeddings(
texts_to_embed,
self.embedding_model,
self.embedding_mode,
use_server=False,
is_build=True,
provider_options=self.embedding_options,
)
embedding_dim = embeddings.shape[1]
expected_dim = meta.get("dimensions")
if expected_dim is not None and expected_dim != embedding_dim:
raise ValueError(
f"Dimension mismatch during update: existing index uses {expected_dim}, got {embedding_dim}."
)
from leann_backend_hnsw import faiss # type: ignore
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
if distance_metric == "cosine":
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
index = faiss.read_index(str(index_file))
if hasattr(index, "is_recompute"):
index.is_recompute = needs_recompute
print(f"index.is_recompute: {index.is_recompute}")
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
# Faiss expects storage.ntotal to reflect the existing graph's
# population (even if the vectors themselves were pruned from disk
# for recompute mode). When we attach a fresh IndexFlat here its
# ntotal starts at zero, which later causes IndexHNSW::add to
# believe new "preset" levels were provided and trips the
# `n0 + n == levels.size()` assertion. Seed the temporary storage
# with the current ntotal so Faiss maintains the proper offset for
# incoming vectors.
try:
storage_index.ntotal = index.ntotal
except AttributeError:
# Older Faiss builds may not expose ntotal as a writable
# attribute; in that case we fall back to the default behaviour.
pass
if index.d != embedding_dim:
raise ValueError(
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
)
passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
passage_provider_options = meta.get("embedding_options", self.embedding_options)
base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id
# Append passages/offsets before we attempt index.add so the ZMQ server
# can resolve newly assigned IDs during recompute. Keep rollback hooks
# so we can restore files if the update fails mid-way.
rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
offset_map_backup = offset_map.copy()
try:
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
server_manager: Optional[EmbeddingServerManager] = None
server_started = False
requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
try:
if needs_recompute:
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started, actual_port = server_manager.start_server(
port=requested_zmq_port,
model_name=self.embedding_model,
embedding_mode=passage_meta_mode,
passages_file=str(meta_path),
distance_metric=distance_metric,
provider_options=passage_provider_options,
)
if not server_started:
raise RuntimeError(
"Failed to start HNSW embedding server for recompute update."
)
if actual_port != requested_zmq_port:
logger.warning(
"Embedding server started on port %s instead of requested %s. "
"Using reassigned port.",
actual_port,
requested_zmq_port,
)
try:
index.hnsw.zmq_port = actual_port
except AttributeError:
pass
if needs_recompute:
for i in range(embeddings.shape[0]):
print(f"add {i} embeddings")
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
else:
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
faiss.write_index(index, str(index_file))
finally:
if server_started and server_manager is not None:
server_manager.stop_server()
except Exception:
# Roll back appended passages/offset map to keep files consistent.
if passages_file.exists():
with open(passages_file, "rb+") as f:
f.truncate(rollback_passages_size)
offset_map = offset_map_backup
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
raise
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
logger.info(
"Appended %d passages to index '%s'. New total: %d",
len(valid_chunks),
index_path,
len(offset_map),
)
self.chunks.clear()
if needs_recompute:
prune_hnsw_embeddings_inplace(str(index_file))
class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
@@ -628,6 +885,7 @@ class LeannSearcher:
self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
self.embedding_options = self.meta_data.get("embedding_options", {})
# Delegate portability handling to PassageManager
self.passage_manager = PassageManager(
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
@@ -639,6 +897,8 @@ class LeannSearcher:
raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup
if self.embedding_options:
final_kwargs.setdefault("embedding_options", self.embedding_options)
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
index_path, **final_kwargs
)
@@ -983,19 +1243,14 @@ class LeannChat:
return ans
def start_interactive(self):
print("\nLeann Chat started (type 'quit' to exit)")
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() in ["quit", "exit"]:
break
if not user_input:
continue
response = self.ask(user_input)
print(f"Leann: {response}")
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break
"""Start interactive chat session."""
session = create_api_session()
def handle_query(user_input: str):
response = self.ask(user_input)
print(f"Leann: {response}")
session.run_interactive_loop(handle_query)
def cleanup(self):
"""Explicitly cleanup embedding server resources.

View File

@@ -12,6 +12,8 @@ from typing import Any, Optional
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -310,11 +312,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
def validate_model_and_suggest(
model_name: str, llm_type: str, host: str = "http://localhost:11434"
model_name: str, llm_type: str, host: Optional[str] = None
) -> Optional[str]:
"""Validate model name and provide suggestions if invalid"""
if llm_type == "ollama":
available_models = check_ollama_models(host)
resolved_host = resolve_ollama_host(host)
available_models = check_ollama_models(resolved_host)
if available_models and model_name not in available_models:
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
@@ -457,19 +460,19 @@ class LLMInterface(ABC):
class OllamaChat(LLMInterface):
"""LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
self.model = model
self.host = host
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
self.host = resolve_ollama_host(host)
logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
try:
import requests
# Check if the Ollama server is responsive
if host:
requests.get(host)
if self.host:
requests.get(self.host)
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama", host)
model_error = validate_model_and_suggest(model, "ollama", self.host)
if model_error:
raise ValueError(model_error)
@@ -478,9 +481,11 @@ class OllamaChat(LLMInterface):
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
)
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
logger.error(
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
)
raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
)
def ask(self, prompt: str, **kwargs) -> str:
@@ -541,11 +546,30 @@ class OllamaChat(LLMInterface):
class HFChat(LLMInterface):
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
"""LLM interface for local Hugging Face Transformers models with proper chat templates.
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
Args:
model_name (str): Name of the Hugging Face model to load.
trust_remote_code (bool): Whether to allow execution of code from the model repository.
Defaults to False for security. Only enable for trusted models as this can pose
a security risk if the model repository is compromised.
"""
def __init__(
self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat", trust_remote_code: bool = False
):
logger.info(f"Initializing HFChat with model='{model_name}'")
# Security warning when trust_remote_code is enabled
if trust_remote_code:
logger.warning(
"SECURITY WARNING: trust_remote_code=True allows execution of arbitrary code from the model repository. "
"Only enable this for models from trusted sources. This creates a potential security risk if the model "
"repository is compromised."
)
self.trust_remote_code = trust_remote_code
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model_name, "hf")
if model_error:
@@ -583,14 +607,16 @@ class HFChat(LLMInterface):
try:
logger.info(f"Loading tokenizer for {model_name}...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name, trust_remote_code=self.trust_remote_code
)
logger.info(f"Loading model {model_name}...")
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True,
trust_remote_code=self.trust_remote_code,
)
logger.info(f"Successfully loaded {model_name}")
finally:
@@ -737,21 +763,31 @@ class GeminiChat(LLMInterface):
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
def __init__(
self,
model: str = "gpt-4o",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.base_url = resolve_openai_base_url(base_url)
self.api_key = resolve_openai_api_key(api_key)
if not self.api_key:
raise ValueError(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
)
logger.info(f"Initializing OpenAI Chat with model='{model}'")
logger.info(
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
model,
self.base_url,
)
try:
import openai
self.client = openai.OpenAI(api_key=self.api_key)
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
except ImportError:
raise ImportError(
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
@@ -841,12 +877,19 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
if llm_type == "ollama":
return OllamaChat(
model=model or "llama3:8b",
host=llm_config.get("host", "http://localhost:11434"),
host=llm_config.get("host"),
)
elif llm_type == "hf":
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
return HFChat(
model_name=model or "deepseek-ai/deepseek-llm-7b-chat",
trust_remote_code=llm_config.get("trust_remote_code", False),
)
elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
return OpenAIChat(
model=model or "gpt-4o",
api_key=llm_config.get("api_key"),
base_url=llm_config.get("base_url"),
)
elif llm_type == "gemini":
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
elif llm_type == "simulated":

View File

@@ -8,7 +8,9 @@ from llama_index.core.node_parser import SentenceSplitter
from tqdm import tqdm
from .api import LeannBuilder, LeannChat, LeannSearcher
from .interactive_utils import create_cli_session
from .registry import register_project_directory
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
@@ -123,6 +125,24 @@ Examples:
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers)",
)
build_parser.add_argument(
"--embedding-host",
type=str,
default=None,
help="Override Ollama-compatible embedding host",
)
build_parser.add_argument(
"--embedding-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible embedding services",
)
build_parser.add_argument(
"--embedding-api-key",
type=str,
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
build_parser.add_argument(
"--force", "-f", action="store_true", help="Force rebuild existing index"
)
@@ -238,6 +258,11 @@ Examples:
# Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions")
ask_parser.add_argument("index_name", help="Index name")
ask_parser.add_argument(
"query",
nargs="?",
help="Question to ask (omit for prompt or when using --interactive)",
)
ask_parser.add_argument(
"--llm",
type=str,
@@ -248,7 +273,12 @@ Examples:
ask_parser.add_argument(
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
)
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
ask_parser.add_argument(
"--host",
type=str,
default=None,
help="Override Ollama-compatible host (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
)
ask_parser.add_argument(
"--interactive", "-i", action="store_true", help="Interactive chat mode"
)
@@ -277,6 +307,18 @@ Examples:
default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
)
ask_parser.add_argument(
"--api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible APIs (e.g., http://localhost:10000/v1)",
)
ask_parser.add_argument(
"--api-key",
type=str,
default=None,
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
)
# List command
subparsers.add_parser("list", help="List all indexes")
@@ -1325,10 +1367,20 @@ Examples:
print(f"Building index '{index_name}' with {args.backend} backend...")
embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
elif args.embedding_mode == "openai":
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder(
backend_name=args.backend,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
graph_degree=args.graph_degree,
complexity=args.complexity,
is_compact=args.compact,
@@ -1476,58 +1528,49 @@ Examples:
llm_config = {"type": args.llm, "model": args.model}
if args.llm == "ollama":
llm_config["host"] = args.host
llm_config["host"] = resolve_ollama_host(args.host)
elif args.llm == "openai":
llm_config["base_url"] = resolve_openai_base_url(args.api_base)
resolved_api_key = resolve_openai_api_key(args.api_key)
if resolved_api_key:
llm_config["api_key"] = resolved_api_key
chat = LeannChat(index_path=index_path, llm_config=llm_config)
llm_kwargs: dict[str, Any] = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
def _ask_once(prompt: str) -> None:
response = chat.ask(
prompt,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
initial_query = (args.query or "").strip()
if args.interactive:
print("LEANN Assistant ready! Type 'quit' to exit")
print("=" * 40)
# Create interactive session
session = create_cli_session(index_name)
while True:
user_input = input("\nYou: ").strip()
if user_input.lower() in ["quit", "exit", "q"]:
print("Goodbye!")
break
if initial_query:
_ask_once(initial_query)
if not user_input:
continue
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
user_input,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
session.run_interactive_loop(_ask_once)
else:
query = input("Enter your question: ").strip()
if query:
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
query = initial_query or input("Enter your question: ").strip()
if not query:
print("No question provided. Exiting.")
return
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
_ask_once(query)
async def run(self, args=None):
parser = self.create_parser()

View File

@@ -7,11 +7,13 @@ Preserves all optimization parameters to ensure performance
import logging
import os
import time
from typing import Any
from typing import Any, Optional
import numpy as np
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Set up logger with proper level
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -31,6 +33,7 @@ def compute_embeddings(
adaptive_optimization: bool = True,
manual_tokenize: bool = False,
max_length: int = 512,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
"""
Unified embedding computation entry point
@@ -46,6 +49,8 @@ def compute_embeddings(
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
provider_options = provider_options or {}
if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(
texts,
@@ -57,11 +62,21 @@ def compute_embeddings(
max_length=max_length,
)
elif mode == "openai":
return compute_embeddings_openai(texts, model_name)
return compute_embeddings_openai(
texts,
model_name,
base_url=provider_options.get("base_url"),
api_key=provider_options.get("api_key"),
)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
elif mode == "ollama":
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
return compute_embeddings_ollama(
texts,
model_name,
is_build=is_build,
host=provider_options.get("host"),
)
elif mode == "gemini":
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
else:
@@ -168,32 +183,73 @@ def compute_embeddings_sentence_transformers(
}
try:
# Try local loading first
model_kwargs["local_files_only"] = True
tokenizer_kwargs["local_files_only"] = True
# Try loading with advanced parameters first (newer versions)
local_model_kwargs = model_kwargs.copy()
local_tokenizer_kwargs = tokenizer_kwargs.copy()
local_model_kwargs["local_files_only"] = True
local_tokenizer_kwargs["local_files_only"] = True
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
model_kwargs=local_model_kwargs,
tokenizer_kwargs=local_tokenizer_kwargs,
local_files_only=True,
)
logger.info("Model loaded successfully! (local + optimized)")
except TypeError as e:
if "model_kwargs" in str(e) or "tokenizer_kwargs" in str(e):
logger.warning(
f"Advanced parameters not supported ({e}), using basic initialization..."
)
# Fallback to basic initialization for older versions
try:
model = SentenceTransformer(
model_name,
device=device,
local_files_only=True,
)
logger.info("Model loaded successfully! (local + basic)")
except Exception as e2:
logger.warning(f"Local loading failed ({e2}), trying network download...")
model = SentenceTransformer(
model_name,
device=device,
local_files_only=False,
)
logger.info("Model loaded successfully! (network + basic)")
else:
raise
except Exception as e:
logger.warning(f"Local loading failed ({e}), trying network download...")
# Fallback to network loading
model_kwargs["local_files_only"] = False
tokenizer_kwargs["local_files_only"] = False
# Fallback to network loading with advanced parameters
try:
network_model_kwargs = model_kwargs.copy()
network_tokenizer_kwargs = tokenizer_kwargs.copy()
network_model_kwargs["local_files_only"] = False
network_tokenizer_kwargs["local_files_only"] = False
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=False,
)
logger.info("Model loaded successfully! (network + optimized)")
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=network_model_kwargs,
tokenizer_kwargs=network_tokenizer_kwargs,
local_files_only=False,
)
logger.info("Model loaded successfully! (network + optimized)")
except TypeError as e2:
if "model_kwargs" in str(e2) or "tokenizer_kwargs" in str(e2):
logger.warning(
f"Advanced parameters not supported ({e2}), using basic network loading..."
)
model = SentenceTransformer(
model_name,
device=device,
local_files_only=False,
)
logger.info("Model loaded successfully! (network + basic)")
else:
raise
# Apply additional optimizations based on mode
if use_fp16 and device in ["cuda", "mps"]:
@@ -353,12 +409,15 @@ def compute_embeddings_sentence_transformers(
return embeddings
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
def compute_embeddings_openai(
texts: list[str],
model_name: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
try:
import os
import openai
except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}")
@@ -373,16 +432,18 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
)
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
resolved_base_url = resolve_openai_base_url(base_url)
resolved_api_key = resolve_openai_api_key(api_key)
if not resolved_api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
# Cache OpenAI client
cache_key = "openai_client"
cache_key = f"openai_client::{resolved_base_url}"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = openai.OpenAI(api_key=api_key)
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
_model_cache[cache_key] = client
logger.info("OpenAI client cached")
@@ -507,7 +568,10 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
def compute_embeddings_ollama(
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
texts: list[str],
model_name: str,
is_build: bool = False,
host: Optional[str] = None,
) -> np.ndarray:
"""
Compute embeddings using Ollama API with simplified batch processing.
@@ -518,7 +582,7 @@ def compute_embeddings_ollama(
texts: List of texts to compute embeddings for
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
is_build: Whether this is a build operation (shows progress bar)
host: Ollama host URL (default: http://localhost:11434)
host: Ollama host URL (defaults to environment or http://localhost:11434)
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
@@ -533,17 +597,19 @@ def compute_embeddings_ollama(
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
resolved_host = resolve_ollama_host(host)
logger.info(
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'"
)
# Check if Ollama is running
try:
response = requests.get(f"{host}/api/version", timeout=5)
response = requests.get(f"{resolved_host}/api/version", timeout=5)
response.raise_for_status()
except requests.exceptions.ConnectionError:
error_msg = (
f"❌ Could not connect to Ollama at {host}.\n\n"
f"❌ Could not connect to Ollama at {resolved_host}.\n\n"
"Please ensure Ollama is running:\n"
" • macOS/Linux: ollama serve\n"
" • Windows: Make sure Ollama is running in the system tray\n\n"
@@ -555,7 +621,7 @@ def compute_embeddings_ollama(
# Check if model exists and provide helpful suggestions
try:
response = requests.get(f"{host}/api/tags", timeout=5)
response = requests.get(f"{resolved_host}/api/tags", timeout=5)
response.raise_for_status()
models = response.json()
model_names = [model["name"] for model in models.get("models", [])]
@@ -618,7 +684,9 @@ def compute_embeddings_ollama(
# Verify the model supports embeddings by testing it
try:
test_response = requests.post(
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": "test"},
timeout=10,
)
if test_response.status_code != 200:
error_msg = (
@@ -665,7 +733,7 @@ def compute_embeddings_ollama(
while retry_count < max_retries:
try:
response = requests.post(
f"{host}/api/embeddings",
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
)

View File

@@ -1,4 +1,5 @@
import atexit
import json
import logging
import os
import socket
@@ -8,6 +9,8 @@ import time
from pathlib import Path
from typing import Optional
from .settings import encode_provider_options
# Lightweight, self-contained server manager with no cross-process inspection
# Set up logging based on environment variable
@@ -46,6 +49,85 @@ def _check_port(port: int) -> bool:
# Note: All cross-process scanning helpers removed for simplicity
def _safe_resolve(path: Path) -> str:
"""Resolve paths safely even if the target does not yet exist."""
try:
return str(path.resolve(strict=False))
except Exception:
return str(path)
def _safe_stat_signature(path: Path) -> dict:
"""Return a lightweight signature describing the current state of a path."""
signature: dict[str, object] = {"path": _safe_resolve(path)}
try:
stat = path.stat()
except FileNotFoundError:
signature["missing"] = True
except Exception as exc: # pragma: no cover - unexpected filesystem errors
signature["error"] = str(exc)
else:
signature["mtime_ns"] = stat.st_mtime_ns
signature["size"] = stat.st_size
return signature
def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
"""Collect modification signatures for metadata and referenced passage files."""
if not passages_file:
return None
meta_path = Path(passages_file)
signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}
try:
with meta_path.open(encoding="utf-8") as fh:
meta = json.load(fh)
except FileNotFoundError:
signature["meta_missing"] = True
signature["sources"] = []
return signature
except json.JSONDecodeError as exc:
signature["meta_error"] = f"json_error:{exc}"
signature["sources"] = []
return signature
except Exception as exc: # pragma: no cover - unexpected errors
signature["meta_error"] = str(exc)
signature["sources"] = []
return signature
base_dir = meta_path.parent
seen_paths: set[str] = set()
source_signatures: list[dict[str, object]] = []
for source in meta.get("passage_sources", []):
for key, kind in (
("path", "passages"),
("path_relative", "passages"),
("index_path", "index"),
("index_path_relative", "index"),
):
raw_path = source.get(key)
if not raw_path:
continue
candidate = Path(raw_path)
if not candidate.is_absolute():
candidate = base_dir / candidate
resolved = _safe_resolve(candidate)
if resolved in seen_paths:
continue
seen_paths.add(resolved)
sig = _safe_stat_signature(candidate)
sig["kind"] = kind
source_signatures.append(sig)
signature["sources"] = source_signatures
return signature
# Note: All cross-process scanning helpers removed for simplicity
class EmbeddingServerManager:
"""
A simplified manager for embedding server processes that avoids complex update mechanisms.
@@ -82,16 +164,42 @@ class EmbeddingServerManager:
) -> tuple[bool, int]:
"""Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here
provider_options = kwargs.pop("provider_options", None)
passages_file = kwargs.get("passages_file", "")
config_signature = self._build_config_signature(
model_name=model_name,
embedding_mode=embedding_mode,
provider_options=provider_options,
passages_file=passages_file,
)
# If this manager already has a live server, just reuse it
if self.server_process and self.server_process.poll() is None and self.server_port:
if (
self.server_process
and self.server_process.poll() is None
and self.server_port
and self._server_config == config_signature
):
logger.info("Reusing in-process server")
return True, self.server_port
# Configuration changed, stop existing server before starting a new one
if self.server_process and self.server_process.poll() is None:
logger.info("Existing server configuration differs; restarting embedding server")
self.stop_server()
# For Colab environment, use a different strategy
if _is_colab_environment():
logger.info("Detected Colab environment, using alternative startup strategy")
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
return self._start_server_colab(
port,
model_name,
embedding_mode,
config_signature=config_signature,
provider_options=provider_options,
**kwargs,
)
# Always pick a fresh available port
try:
@@ -101,13 +209,40 @@ class EmbeddingServerManager:
return False, port
# Start a new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
return self._start_new_server(
actual_port,
model_name,
embedding_mode,
provider_options=provider_options,
config_signature=config_signature,
**kwargs,
)
def _build_config_signature(
self,
*,
model_name: str,
embedding_mode: str,
provider_options: Optional[dict],
passages_file: Optional[str],
) -> dict:
"""Create a signature describing the current server configuration."""
return {
"model_name": model_name,
"passages_file": passages_file or "",
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
"passages_signature": _build_passages_signature(passages_file),
}
def _start_server_colab(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
*,
config_signature: Optional[dict] = None,
provider_options: Optional[dict] = None,
**kwargs,
) -> tuple[bool, int]:
"""Start server with Colab-specific configuration."""
@@ -125,8 +260,21 @@ class EmbeddingServerManager:
try:
# In Colab, we'll use a more direct approach
self._launch_server_process_colab(command, actual_port)
return self._wait_for_server_ready_colab(actual_port)
self._launch_server_process_colab(
command,
actual_port,
provider_options=provider_options,
config_signature=config_signature,
)
started, ready_port = self._wait_for_server_ready_colab(actual_port)
if started:
self._server_config = config_signature or {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
return started, ready_port
except Exception as e:
logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port
@@ -134,7 +282,13 @@ class EmbeddingServerManager:
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs
self,
port: int,
model_name: str,
embedding_mode: str,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
**kwargs,
) -> tuple[bool, int]:
"""Start a new embedding server on the given port."""
logger.info(f"Starting embedding server on port {port}...")
@@ -142,8 +296,21 @@ class EmbeddingServerManager:
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
try:
self._launch_server_process(command, port)
return self._wait_for_server_ready(port)
self._launch_server_process(
command,
port,
provider_options=provider_options,
config_signature=config_signature,
)
started, ready_port = self._wait_for_server_ready(port)
if started:
self._server_config = config_signature or {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
return started, ready_port
except Exception as e:
logger.error(f"Failed to start embedding server: {e}")
return False, port
@@ -173,7 +340,14 @@ class EmbeddingServerManager:
return command
def _launch_server_process(self, command: list, port: int) -> None:
def _launch_server_process(
self,
command: list,
port: int,
*,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None:
"""Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent
logger.info(f"Command: {' '.join(command)}")
@@ -193,32 +367,43 @@ class EmbeddingServerManager:
# Start embedding server subprocess
logger.info(f"Starting server process with command: {' '.join(command)}")
env = os.environ.copy()
encoded_options = encode_provider_options(provider_options)
if encoded_options:
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=stdout_target,
stderr=stderr_target,
env=env,
)
self.server_port = port
# Record config for in-process reuse
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
if "--model-name" in command
else "",
"passages_file": command[command.index("--passages-file") + 1]
if "--passages-file" in command
else "",
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
}
# Record config for in-process reuse (best effort; refined later when ready)
if config_signature is not None:
self._server_config = config_signature
else: # Fallback for unexpected code paths
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
if "--model-name" in command
else "",
"passages_file": command[command.index("--passages-file") + 1]
if "--passages-file" in command
else "",
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
"provider_options": provider_options or {},
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process
@@ -322,16 +507,29 @@ class EmbeddingServerManager:
# Removed: cross-process adoption no longer supported
return
def _launch_server_process_colab(self, command: list, port: int) -> None:
def _launch_server_process_colab(
self,
command: list,
port: int,
*,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None:
"""Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}")
# In Colab, we need to be more careful about process management
env = os.environ.copy()
encoded_options = encode_provider_options(provider_options)
if encoded_options:
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
self.server_process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
env=env,
)
self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
@@ -341,11 +539,15 @@ class EmbeddingServerManager:
atexit.register(self._finalize_process)
self._atexit_registered = True
# Record config for in-process reuse is best-effort in Colab mode
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
}
if config_signature is not None:
self._server_config = config_signature
else:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""

View File

@@ -0,0 +1,189 @@
"""
Interactive session utilities for LEANN applications.
Provides shared readline functionality and command handling across
CLI, API, and RAG example interactive modes.
"""
import atexit
import os
from pathlib import Path
from typing import Callable, Optional
# Try to import readline with fallback for Windows
try:
import readline
HAS_READLINE = True
except ImportError:
# Windows doesn't have readline by default
HAS_READLINE = False
readline = None
class InteractiveSession:
"""Manages interactive session with optional readline support and common commands."""
def __init__(
self,
history_name: str,
prompt: str = "You: ",
welcome_message: str = "",
):
"""
Initialize interactive session with optional readline support.
Args:
history_name: Name for history file (e.g., "cli", "api_chat")
(ignored if readline not available)
prompt: Input prompt to display
welcome_message: Message to show when starting session
Note:
On systems without readline (e.g., Windows), falls back to basic input()
with limited functionality (no history, no line editing).
"""
self.history_name = history_name
self.prompt = prompt
self.welcome_message = welcome_message
self._setup_complete = False
def setup_readline(self):
"""Setup readline with history support (if available)."""
if self._setup_complete:
return
if not HAS_READLINE:
# Readline not available (likely Windows), skip setup
self._setup_complete = True
return
# History file setup
history_dir = Path.home() / ".leann" / "history"
history_dir.mkdir(parents=True, exist_ok=True)
history_file = history_dir / f"{self.history_name}.history"
# Load history if exists
try:
readline.read_history_file(str(history_file))
readline.set_history_length(1000)
except FileNotFoundError:
pass
# Save history on exit
atexit.register(readline.write_history_file, str(history_file))
# Optional: Enable vi editing mode (commented out by default)
# readline.parse_and_bind("set editing-mode vi")
self._setup_complete = True
def _show_help(self):
"""Show available commands."""
print("Commands:")
print(" quit/exit/q - Exit the chat")
print(" help - Show this help message")
print(" clear - Clear screen")
print(" history - Show command history")
def _show_history(self):
"""Show command history."""
if not HAS_READLINE:
print(" History not available (readline not supported on this system)")
return
history_length = readline.get_current_history_length()
if history_length == 0:
print(" No history available")
return
for i in range(history_length):
item = readline.get_history_item(i + 1)
if item:
print(f" {i + 1}: {item}")
def get_user_input(self) -> Optional[str]:
"""
Get user input with readline support.
Returns:
User input string, or None if EOF (Ctrl+D)
"""
try:
return input(self.prompt).strip()
except KeyboardInterrupt:
print("\n(Use 'quit' to exit)")
return "" # Return empty string to continue
except EOFError:
print("\nGoodbye!")
return None
def run_interactive_loop(self, handler_func: Callable[[str], None]):
"""
Run the interactive loop with a custom handler function.
Args:
handler_func: Function to handle user input that's not a built-in command
Should accept a string and handle the user's query
"""
self.setup_readline()
if self.welcome_message:
print(self.welcome_message)
while True:
user_input = self.get_user_input()
if user_input is None: # EOF (Ctrl+D)
break
if not user_input: # Empty input or KeyboardInterrupt
continue
# Handle built-in commands
command = user_input.lower()
if command in ["quit", "exit", "q"]:
print("Goodbye!")
break
elif command == "help":
self._show_help()
elif command == "clear":
os.system("clear" if os.name != "nt" else "cls")
elif command == "history":
self._show_history()
else:
# Regular user input - pass to handler
try:
handler_func(user_input)
except Exception as e:
print(f"Error: {e}")
def create_cli_session(index_name: str) -> InteractiveSession:
"""Create an interactive session for CLI usage."""
return InteractiveSession(
history_name=index_name,
prompt="\nYou: ",
welcome_message="LEANN Assistant ready! Type 'quit' to exit, 'help' for commands\n"
+ "=" * 40,
)
def create_api_session() -> InteractiveSession:
"""Create an interactive session for API chat."""
return InteractiveSession(
history_name="api_chat",
prompt="You: ",
welcome_message="Leann Chat started (type 'quit' to exit, 'help' for commands)\n"
+ "=" * 40,
)
def create_rag_session(app_name: str, data_description: str) -> InteractiveSession:
"""Create an interactive session for RAG examples."""
return InteractiveSession(
history_name=f"{app_name}_rag",
prompt="You: ",
welcome_message=f"[Interactive Mode] Chat with your {data_description} data!\nType 'quit' or 'exit' to stop, 'help' for commands.\n"
+ "=" * 40,
)

View File

@@ -41,6 +41,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
self.embedding_options = self.meta.get("embedding_options", {})
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name,
@@ -77,6 +78,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
passages_file=passages_source_file,
distance_metric=distance_metric,
enable_warmup=kwargs.get("enable_warmup", False),
provider_options=self.embedding_options,
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
@@ -125,7 +127,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
from .embedding_compute import compute_embeddings
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
return compute_embeddings([query], self.embedding_model, embedding_mode)
return compute_embeddings(
[query],
self.embedding_model,
embedding_mode,
provider_options=self.embedding_options,
)
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
"""Compute embeddings using the ZMQ embedding server."""

View File

@@ -0,0 +1,74 @@
"""Runtime configuration helpers for LEANN."""
from __future__ import annotations
import json
import os
from typing import Any
# Default fallbacks to preserve current behaviour while keeping them in one place.
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
def _clean_url(value: str) -> str:
"""Normalize URL strings by stripping trailing slashes."""
return value.rstrip("/") if value else value
def resolve_ollama_host(explicit: str | None = None) -> str:
"""Resolve the Ollama-compatible endpoint to use."""
candidates = (
explicit,
os.getenv("LEANN_LOCAL_LLM_HOST"),
os.getenv("LEANN_OLLAMA_HOST"),
os.getenv("OLLAMA_HOST"),
os.getenv("LOCAL_LLM_ENDPOINT"),
)
for candidate in candidates:
if candidate:
return _clean_url(candidate)
return _clean_url(_DEFAULT_OLLAMA_HOST)
def resolve_openai_base_url(explicit: str | None = None) -> str:
"""Resolve the base URL for OpenAI-compatible services."""
candidates = (
explicit,
os.getenv("LEANN_OPENAI_BASE_URL"),
os.getenv("OPENAI_BASE_URL"),
os.getenv("LOCAL_OPENAI_BASE_URL"),
)
for candidate in candidates:
if candidate:
return _clean_url(candidate)
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
"""Resolve the API key for OpenAI-compatible services."""
if explicit:
return explicit
return os.getenv("OPENAI_API_KEY")
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
"""Serialize provider options for child processes."""
if not options:
return None
try:
return json.dumps(options)
except (TypeError, ValueError):
# Fall back to empty payload if serialization fails
return None