Merge remote-tracking branch 'origin/main' into financebench
This commit is contained in:
@@ -10,7 +10,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
@@ -32,6 +32,16 @@ if not logger.handlers:
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
|
||||
try:
|
||||
PROVIDER_OPTIONS: dict[str, Any] = (
|
||||
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
|
||||
PROVIDER_OPTIONS = {}
|
||||
|
||||
|
||||
def create_diskann_embedding_server(
|
||||
passages_file: Optional[str] = None,
|
||||
zmq_port: int = 5555,
|
||||
@@ -181,7 +191,12 @@ def create_diskann_embedding_server(
|
||||
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
||||
|
||||
# Process embeddings using unified computation
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
@@ -296,7 +311,12 @@ def create_diskann_embedding_server(
|
||||
continue
|
||||
|
||||
# Process the request
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
||||
|
||||
# Validation
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
[build-system]
|
||||
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy"]
|
||||
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy", "cmake>=3.30"]
|
||||
build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-diskann"
|
||||
version = "0.3.3"
|
||||
dependencies = ["leann-core==0.3.3", "numpy", "protobuf>=3.19.0"]
|
||||
version = "0.3.4"
|
||||
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
|
||||
|
||||
[tool.scikit-build]
|
||||
# Key: simplified CMake path
|
||||
|
||||
@@ -5,6 +5,8 @@ import os
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -237,6 +239,288 @@ def write_compact_format(
|
||||
f_out.write(storage_data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HNSWComponents:
|
||||
original_hnsw_data: dict[str, Any]
|
||||
assign_probas_np: np.ndarray
|
||||
cum_nneighbor_per_level_np: np.ndarray
|
||||
levels_np: np.ndarray
|
||||
is_compact: bool
|
||||
compact_level_ptr: Optional[np.ndarray] = None
|
||||
compact_node_offsets_np: Optional[np.ndarray] = None
|
||||
compact_neighbors_data: Optional[list[int]] = None
|
||||
offsets_np: Optional[np.ndarray] = None
|
||||
neighbors_np: Optional[np.ndarray] = None
|
||||
storage_fourcc: int = NULL_INDEX_FOURCC
|
||||
storage_data: bytes = b""
|
||||
|
||||
|
||||
def _read_hnsw_structure(f) -> HNSWComponents:
|
||||
original_hnsw_data: dict[str, Any] = {}
|
||||
|
||||
hnsw_index_fourcc = read_struct(f, "<I")
|
||||
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||
raise ValueError(
|
||||
f"Unexpected HNSW FourCC: {hnsw_index_fourcc:08x}. Expected one of {EXPECTED_HNSW_FOURCCS}."
|
||||
)
|
||||
|
||||
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||
original_hnsw_data["d"] = read_struct(f, "<i")
|
||||
original_hnsw_data["ntotal"] = read_struct(f, "<q")
|
||||
original_hnsw_data["dummy1"] = read_struct(f, "<q")
|
||||
original_hnsw_data["dummy2"] = read_struct(f, "<q")
|
||||
original_hnsw_data["is_trained"] = read_struct(f, "?")
|
||||
original_hnsw_data["metric_type"] = read_struct(f, "<i")
|
||||
original_hnsw_data["metric_arg"] = 0.0
|
||||
if original_hnsw_data["metric_type"] > 1:
|
||||
original_hnsw_data["metric_arg"] = read_struct(f, "<f")
|
||||
|
||||
assign_probas_np = read_numpy_vector(f, np.float64, "d")
|
||||
cum_nneighbor_per_level_np = read_numpy_vector(f, np.int32, "i")
|
||||
levels_np = read_numpy_vector(f, np.int32, "i")
|
||||
|
||||
ntotal = len(levels_np)
|
||||
if ntotal != original_hnsw_data["ntotal"]:
|
||||
original_hnsw_data["ntotal"] = ntotal
|
||||
|
||||
pos_before_compact = f.tell()
|
||||
is_compact_flag = None
|
||||
try:
|
||||
is_compact_flag = read_struct(f, "<?")
|
||||
except EOFError:
|
||||
is_compact_flag = None
|
||||
|
||||
if is_compact_flag:
|
||||
compact_level_ptr = read_numpy_vector(f, np.uint64, "Q")
|
||||
compact_node_offsets_np = read_numpy_vector(f, np.uint64, "Q")
|
||||
|
||||
original_hnsw_data["entry_point"] = read_struct(f, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
|
||||
|
||||
storage_fourcc = read_struct(f, "<I")
|
||||
compact_neighbors_data_np = read_numpy_vector(f, np.int32, "i")
|
||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||
storage_data = f.read()
|
||||
|
||||
return HNSWComponents(
|
||||
original_hnsw_data=original_hnsw_data,
|
||||
assign_probas_np=assign_probas_np,
|
||||
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
|
||||
levels_np=levels_np,
|
||||
is_compact=True,
|
||||
compact_level_ptr=compact_level_ptr,
|
||||
compact_node_offsets_np=compact_node_offsets_np,
|
||||
compact_neighbors_data=compact_neighbors_data,
|
||||
storage_fourcc=storage_fourcc,
|
||||
storage_data=storage_data,
|
||||
)
|
||||
|
||||
# Non-compact case
|
||||
f.seek(pos_before_compact)
|
||||
|
||||
pos_before_probe = f.tell()
|
||||
try:
|
||||
suspected_flag = read_struct(f, "<B")
|
||||
if suspected_flag != 0x00:
|
||||
f.seek(pos_before_probe)
|
||||
except EOFError:
|
||||
f.seek(pos_before_probe)
|
||||
|
||||
offsets_np = read_numpy_vector(f, np.uint64, "Q")
|
||||
neighbors_np = read_numpy_vector(f, np.int32, "i")
|
||||
|
||||
original_hnsw_data["entry_point"] = read_struct(f, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
|
||||
|
||||
storage_fourcc = NULL_INDEX_FOURCC
|
||||
storage_data = b""
|
||||
try:
|
||||
storage_fourcc = read_struct(f, "<I")
|
||||
storage_data = f.read()
|
||||
except EOFError:
|
||||
storage_fourcc = NULL_INDEX_FOURCC
|
||||
|
||||
return HNSWComponents(
|
||||
original_hnsw_data=original_hnsw_data,
|
||||
assign_probas_np=assign_probas_np,
|
||||
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
|
||||
levels_np=levels_np,
|
||||
is_compact=False,
|
||||
offsets_np=offsets_np,
|
||||
neighbors_np=neighbors_np,
|
||||
storage_fourcc=storage_fourcc,
|
||||
storage_data=storage_data,
|
||||
)
|
||||
|
||||
|
||||
def _read_hnsw_structure_from_file(path: str) -> HNSWComponents:
|
||||
with open(path, "rb") as f:
|
||||
return _read_hnsw_structure(f)
|
||||
|
||||
|
||||
def write_original_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
offsets_np,
|
||||
neighbors_np,
|
||||
storage_fourcc,
|
||||
storage_data,
|
||||
):
|
||||
"""Write non-compact HNSW data in original FAISS order."""
|
||||
|
||||
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
|
||||
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
|
||||
if original_hnsw_data["metric_type"] > 1:
|
||||
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
|
||||
|
||||
write_numpy_vector(f_out, assign_probas_np, "d")
|
||||
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
|
||||
write_numpy_vector(f_out, levels_np, "i")
|
||||
|
||||
write_numpy_vector(f_out, offsets_np, "Q")
|
||||
write_numpy_vector(f_out, neighbors_np, "i")
|
||||
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
|
||||
|
||||
f_out.write(struct.pack("<I", storage_fourcc))
|
||||
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
|
||||
f_out.write(storage_data)
|
||||
|
||||
|
||||
def prune_hnsw_embeddings(input_filename: str, output_filename: str) -> bool:
|
||||
"""Rewrite an HNSW index while dropping the embedded storage section."""
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
|
||||
original_hnsw_data: dict[str, Any] = {}
|
||||
|
||||
hnsw_index_fourcc = read_struct(f_in, "<I")
|
||||
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||
print(
|
||||
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return False
|
||||
|
||||
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||
original_hnsw_data["d"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
|
||||
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["metric_arg"] = 0.0
|
||||
if original_hnsw_data["metric_type"] > 1:
|
||||
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
|
||||
|
||||
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
|
||||
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
levels_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
|
||||
ntotal = len(levels_np)
|
||||
if ntotal != original_hnsw_data["ntotal"]:
|
||||
original_hnsw_data["ntotal"] = ntotal
|
||||
|
||||
pos_before_compact = f_in.tell()
|
||||
is_compact_flag = None
|
||||
try:
|
||||
is_compact_flag = read_struct(f_in, "<?")
|
||||
except EOFError:
|
||||
is_compact_flag = None
|
||||
|
||||
if is_compact_flag:
|
||||
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
|
||||
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||
|
||||
_storage_fourcc = read_struct(f_in, "<I")
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||
_storage_data = f_in.read()
|
||||
|
||||
write_compact_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
compact_level_ptr,
|
||||
compact_node_offsets_np,
|
||||
compact_neighbors_data,
|
||||
NULL_INDEX_FOURCC,
|
||||
b"",
|
||||
)
|
||||
else:
|
||||
f_in.seek(pos_before_compact)
|
||||
|
||||
pos_before_probe = f_in.tell()
|
||||
try:
|
||||
suspected_flag = read_struct(f_in, "<B")
|
||||
if suspected_flag != 0x00:
|
||||
f_in.seek(pos_before_probe)
|
||||
except EOFError:
|
||||
f_in.seek(pos_before_probe)
|
||||
|
||||
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
|
||||
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||
|
||||
_storage_fourcc = None
|
||||
_storage_data = b""
|
||||
try:
|
||||
_storage_fourcc = read_struct(f_in, "<I")
|
||||
_storage_data = f_in.read()
|
||||
except EOFError:
|
||||
_storage_fourcc = NULL_INDEX_FOURCC
|
||||
|
||||
write_original_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
offsets_np,
|
||||
neighbors_np,
|
||||
NULL_INDEX_FOURCC,
|
||||
b"",
|
||||
)
|
||||
|
||||
print(f"[{time.time() - start_time:.2f}s] Pruned embeddings from {input_filename}")
|
||||
return True
|
||||
except Exception as exc:
|
||||
print(f"Failed to prune embeddings: {exc}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
|
||||
# --- Main Conversion Logic ---
|
||||
|
||||
|
||||
@@ -700,6 +984,29 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
pass
|
||||
|
||||
|
||||
def prune_hnsw_embeddings_inplace(index_filename: str) -> bool:
|
||||
"""Convenience wrapper to prune embeddings in-place."""
|
||||
|
||||
temp_path = f"{index_filename}.prune.tmp"
|
||||
success = prune_hnsw_embeddings(index_filename, temp_path)
|
||||
if success:
|
||||
try:
|
||||
os.replace(temp_path, index_filename)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.error(f"Failed to replace original index with pruned version: {exc}")
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
return success
|
||||
|
||||
|
||||
# --- Script Execution ---
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
|
||||
@@ -14,7 +14,7 @@ from leann.interface import (
|
||||
from leann.registry import register_backend
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -101,6 +101,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
|
||||
if self.is_compact:
|
||||
self._convert_to_csr(index_file)
|
||||
elif self.is_recompute:
|
||||
prune_hnsw_embeddings_inplace(str(index_file))
|
||||
|
||||
def _convert_to_csr(self, index_file: Path):
|
||||
"""Convert built index to CSR format"""
|
||||
@@ -142,10 +144,10 @@ class HNSWSearcher(BaseSearcher):
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
|
||||
self.is_compact, self.is_pruned = (
|
||||
self.meta.get("is_compact", True),
|
||||
self.meta.get("is_pruned", True),
|
||||
)
|
||||
backend_meta_kwargs = self.meta.get("backend_kwargs", {})
|
||||
self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
|
||||
default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
|
||||
self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))
|
||||
|
||||
index_file = self.index_dir / f"{self.index_path.stem}.index"
|
||||
if not index_file.exists():
|
||||
|
||||
@@ -10,7 +10,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
@@ -24,13 +24,35 @@ logger = logging.getLogger(__name__)
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# Ensure we have a handler if none exists
|
||||
# Ensure we have handlers if none exist
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
stream_handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False
|
||||
stream_handler.setFormatter(formatter)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
log_path = os.getenv("LEANN_HNSW_LOG_PATH")
|
||||
if log_path:
|
||||
try:
|
||||
file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
|
||||
file_formatter = logging.Formatter(
|
||||
"%(asctime)s - %(levelname)s - [pid=%(process)d] %(message)s"
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
logger.addHandler(file_handler)
|
||||
except Exception as exc: # pragma: no cover - best effort logging
|
||||
logger.warning(f"Failed to attach file handler for log path {log_path}: {exc}")
|
||||
|
||||
logger.propagate = False
|
||||
|
||||
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
|
||||
try:
|
||||
PROVIDER_OPTIONS: dict[str, Any] = (
|
||||
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
|
||||
PROVIDER_OPTIONS = {}
|
||||
|
||||
|
||||
def create_hnsw_embedding_server(
|
||||
@@ -167,7 +189,12 @@ def create_hnsw_embedding_server(
|
||||
):
|
||||
last_request_type = "text"
|
||||
last_request_length = len(request)
|
||||
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
|
||||
embeddings = compute_embeddings(
|
||||
request,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
@@ -217,7 +244,10 @@ def create_hnsw_embedding_server(
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(
|
||||
texts, model_name, mode=embedding_mode
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
@@ -283,7 +313,12 @@ def create_hnsw_embedding_server(
|
||||
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-hnsw"
|
||||
version = "0.3.3"
|
||||
version = "0.3.4"
|
||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||
dependencies = [
|
||||
"leann-core==0.3.3",
|
||||
"leann-core==0.3.4",
|
||||
"numpy",
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
|
||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: ed96ff7dba...1d51f0c074
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann-core"
|
||||
version = "0.3.3"
|
||||
version = "0.3.4"
|
||||
description = "Core API and plugin system for LEANN"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
|
||||
@@ -15,6 +15,7 @@ from pathlib import Path
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||
|
||||
from leann.interface import LeannBackendSearcherInterface
|
||||
|
||||
@@ -38,6 +39,7 @@ def compute_embeddings(
|
||||
use_server: bool = True,
|
||||
port: Optional[int] = None,
|
||||
is_build=False,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes embeddings using different backends.
|
||||
@@ -71,6 +73,7 @@ def compute_embeddings(
|
||||
model_name,
|
||||
mode=mode,
|
||||
is_build=is_build,
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
|
||||
@@ -277,6 +280,7 @@ class LeannBuilder:
|
||||
embedding_model: str = "facebook/contriever",
|
||||
dimensions: Optional[int] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
embedding_options: Optional[dict[str, Any]] = None,
|
||||
**backend_kwargs,
|
||||
):
|
||||
self.backend_name = backend_name
|
||||
@@ -299,6 +303,7 @@ class LeannBuilder:
|
||||
self.embedding_model = embedding_model
|
||||
self.dimensions = dimensions
|
||||
self.embedding_mode = embedding_mode
|
||||
self.embedding_options = embedding_options or {}
|
||||
|
||||
# Check if we need to use cosine distance for normalized embeddings
|
||||
normalized_embeddings_models = {
|
||||
@@ -406,6 +411,7 @@ class LeannBuilder:
|
||||
self.embedding_model,
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
provider_options=self.embedding_options,
|
||||
)[0]
|
||||
)
|
||||
path = Path(index_path)
|
||||
@@ -445,6 +451,7 @@ class LeannBuilder:
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
is_build=True,
|
||||
provider_options=self.embedding_options,
|
||||
)
|
||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||
# Persist ID map alongside index so backends that return integer labels can remap to passage IDs
|
||||
@@ -482,14 +489,15 @@ class LeannBuilder:
|
||||
],
|
||||
}
|
||||
|
||||
if self.embedding_options:
|
||||
meta_data["embedding_options"] = self.embedding_options
|
||||
|
||||
# Add storage status flags for HNSW backend
|
||||
if self.backend_name == "hnsw":
|
||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||
meta_data["is_compact"] = is_compact
|
||||
meta_data["is_pruned"] = (
|
||||
is_compact and is_recompute
|
||||
) # Pruned only if compact and recompute
|
||||
meta_data["is_pruned"] = bool(is_recompute)
|
||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
|
||||
@@ -615,18 +623,166 @@ class LeannBuilder:
|
||||
"embeddings_source": str(embeddings_file),
|
||||
}
|
||||
|
||||
if self.embedding_options:
|
||||
meta_data["embedding_options"] = self.embedding_options
|
||||
|
||||
# Add storage status flags for HNSW backend
|
||||
if self.backend_name == "hnsw":
|
||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||
meta_data["is_compact"] = is_compact
|
||||
meta_data["is_pruned"] = is_compact and is_recompute
|
||||
meta_data["is_pruned"] = bool(is_recompute)
|
||||
|
||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
|
||||
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
|
||||
|
||||
def update_index(self, index_path: str):
|
||||
"""Append new passages and vectors to an existing HNSW index."""
|
||||
if not self.chunks:
|
||||
raise ValueError("No new chunks provided for update.")
|
||||
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
index_name = path.name
|
||||
index_prefix = path.stem
|
||||
|
||||
meta_path = index_dir / f"{index_name}.meta.json"
|
||||
passages_file = index_dir / f"{index_name}.passages.jsonl"
|
||||
offset_file = index_dir / f"{index_name}.passages.idx"
|
||||
index_file = index_dir / f"{index_prefix}.index"
|
||||
|
||||
if not meta_path.exists() or not passages_file.exists() or not offset_file.exists():
|
||||
raise FileNotFoundError("Index metadata or passage files are missing; cannot update.")
|
||||
if not index_file.exists():
|
||||
raise FileNotFoundError(f"HNSW index file not found: {index_file}")
|
||||
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
backend_name = meta.get("backend_name")
|
||||
if backend_name != self.backend_name:
|
||||
raise ValueError(
|
||||
f"Index was built with backend '{backend_name}', cannot update with '{self.backend_name}'."
|
||||
)
|
||||
|
||||
meta_backend_kwargs = meta.get("backend_kwargs", {})
|
||||
index_is_compact = meta.get("is_compact", meta_backend_kwargs.get("is_compact", True))
|
||||
if index_is_compact:
|
||||
raise ValueError(
|
||||
"Compact HNSW indices do not support in-place updates. Rebuild required."
|
||||
)
|
||||
|
||||
distance_metric = meta_backend_kwargs.get(
|
||||
"distance_metric", self.backend_kwargs.get("distance_metric", "mips")
|
||||
).lower()
|
||||
needs_recompute = bool(
|
||||
meta.get("is_pruned")
|
||||
or meta_backend_kwargs.get("is_recompute")
|
||||
or self.backend_kwargs.get("is_recompute")
|
||||
)
|
||||
|
||||
with open(offset_file, "rb") as f:
|
||||
offset_map: dict[str, int] = pickle.load(f)
|
||||
existing_ids = set(offset_map.keys())
|
||||
|
||||
valid_chunks: list[dict[str, Any]] = []
|
||||
for chunk in self.chunks:
|
||||
text = chunk.get("text", "")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
continue
|
||||
metadata = chunk.setdefault("metadata", {})
|
||||
passage_id = chunk.get("id") or metadata.get("id")
|
||||
if passage_id and passage_id in existing_ids:
|
||||
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
|
||||
valid_chunks.append(chunk)
|
||||
|
||||
if not valid_chunks:
|
||||
raise ValueError("No valid chunks to append.")
|
||||
|
||||
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
|
||||
embeddings = compute_embeddings(
|
||||
texts_to_embed,
|
||||
self.embedding_model,
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
is_build=True,
|
||||
provider_options=self.embedding_options,
|
||||
)
|
||||
|
||||
embedding_dim = embeddings.shape[1]
|
||||
expected_dim = meta.get("dimensions")
|
||||
if expected_dim is not None and expected_dim != embedding_dim:
|
||||
raise ValueError(
|
||||
f"Dimension mismatch during update: existing index uses {expected_dim}, got {embedding_dim}."
|
||||
)
|
||||
|
||||
from leann_backend_hnsw import faiss # type: ignore
|
||||
|
||||
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
if distance_metric == "cosine":
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1
|
||||
embeddings = embeddings / norms
|
||||
|
||||
index = faiss.read_index(str(index_file))
|
||||
if hasattr(index, "is_recompute"):
|
||||
index.is_recompute = needs_recompute
|
||||
if getattr(index, "storage", None) is None:
|
||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||
storage_index = faiss.IndexFlatIP(index.d)
|
||||
else:
|
||||
storage_index = faiss.IndexFlatL2(index.d)
|
||||
index.storage = storage_index
|
||||
index.own_fields = True
|
||||
if index.d != embedding_dim:
|
||||
raise ValueError(
|
||||
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
|
||||
)
|
||||
|
||||
base_id = index.ntotal
|
||||
for offset, chunk in enumerate(valid_chunks):
|
||||
new_id = str(base_id + offset)
|
||||
chunk.setdefault("metadata", {})["id"] = new_id
|
||||
chunk["id"] = new_id
|
||||
|
||||
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
|
||||
faiss.write_index(index, str(index_file))
|
||||
|
||||
with open(passages_file, "a", encoding="utf-8") as f:
|
||||
for chunk in valid_chunks:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
{
|
||||
"id": chunk["id"],
|
||||
"text": chunk["text"],
|
||||
"metadata": chunk.get("metadata", {}),
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
f.write("\n")
|
||||
offset_map[chunk["id"]] = offset
|
||||
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
meta["total_passages"] = len(offset_map)
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
logger.info(
|
||||
"Appended %d passages to index '%s'. New total: %d",
|
||||
len(valid_chunks),
|
||||
index_path,
|
||||
len(offset_map),
|
||||
)
|
||||
|
||||
self.chunks.clear()
|
||||
|
||||
if needs_recompute:
|
||||
prune_hnsw_embeddings_inplace(str(index_file))
|
||||
|
||||
|
||||
class LeannSearcher:
|
||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||
@@ -650,6 +806,7 @@ class LeannSearcher:
|
||||
self.embedding_model = self.meta_data["embedding_model"]
|
||||
# Support both old and new format
|
||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||
self.embedding_options = self.meta_data.get("embedding_options", {})
|
||||
# Delegate portability handling to PassageManager
|
||||
self.passage_manager = PassageManager(
|
||||
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
||||
@@ -661,6 +818,8 @@ class LeannSearcher:
|
||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||
final_kwargs["enable_warmup"] = enable_warmup
|
||||
if self.embedding_options:
|
||||
final_kwargs.setdefault("embedding_options", self.embedding_options)
|
||||
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
||||
index_path, **final_kwargs
|
||||
)
|
||||
|
||||
@@ -12,6 +12,8 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -310,11 +312,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
||||
|
||||
|
||||
def validate_model_and_suggest(
|
||||
model_name: str, llm_type: str, host: str = "http://localhost:11434"
|
||||
model_name: str, llm_type: str, host: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""Validate model name and provide suggestions if invalid"""
|
||||
if llm_type == "ollama":
|
||||
available_models = check_ollama_models(host)
|
||||
resolved_host = resolve_ollama_host(host)
|
||||
available_models = check_ollama_models(resolved_host)
|
||||
if available_models and model_name not in available_models:
|
||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||
|
||||
@@ -457,19 +460,19 @@ class LLMInterface(ABC):
|
||||
class OllamaChat(LLMInterface):
|
||||
"""LLM interface for Ollama models."""
|
||||
|
||||
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
|
||||
def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
|
||||
self.model = model
|
||||
self.host = host
|
||||
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
|
||||
self.host = resolve_ollama_host(host)
|
||||
logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
|
||||
try:
|
||||
import requests
|
||||
|
||||
# Check if the Ollama server is responsive
|
||||
if host:
|
||||
requests.get(host)
|
||||
if self.host:
|
||||
requests.get(self.host)
|
||||
|
||||
# Pre-check model availability with helpful suggestions
|
||||
model_error = validate_model_and_suggest(model, "ollama", host)
|
||||
model_error = validate_model_and_suggest(model, "ollama", self.host)
|
||||
if model_error:
|
||||
raise ValueError(model_error)
|
||||
|
||||
@@ -478,9 +481,11 @@ class OllamaChat(LLMInterface):
|
||||
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
|
||||
logger.error(
|
||||
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
|
||||
)
|
||||
raise ConnectionError(
|
||||
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
||||
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
|
||||
)
|
||||
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
@@ -737,21 +742,31 @@ class GeminiChat(LLMInterface):
|
||||
class OpenAIChat(LLMInterface):
|
||||
"""LLM interface for OpenAI models."""
|
||||
|
||||
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4o",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = resolve_openai_base_url(base_url)
|
||||
self.api_key = resolve_openai_api_key(api_key)
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
|
||||
)
|
||||
|
||||
logger.info(f"Initializing OpenAI Chat with model='{model}'")
|
||||
logger.info(
|
||||
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
|
||||
model,
|
||||
self.base_url,
|
||||
)
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
self.client = openai.OpenAI(api_key=self.api_key)
|
||||
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
|
||||
@@ -841,12 +856,16 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
||||
if llm_type == "ollama":
|
||||
return OllamaChat(
|
||||
model=model or "llama3:8b",
|
||||
host=llm_config.get("host", "http://localhost:11434"),
|
||||
host=llm_config.get("host"),
|
||||
)
|
||||
elif llm_type == "hf":
|
||||
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
||||
elif llm_type == "openai":
|
||||
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
|
||||
return OpenAIChat(
|
||||
model=model or "gpt-4o",
|
||||
api_key=llm_config.get("api_key"),
|
||||
base_url=llm_config.get("base_url"),
|
||||
)
|
||||
elif llm_type == "gemini":
|
||||
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
||||
elif llm_type == "simulated":
|
||||
|
||||
220
packages/leann-core/src/leann/chunking_utils.py
Normal file
220
packages/leann-core/src/leann/chunking_utils.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Enhanced chunking utilities with AST-aware code chunking support.
|
||||
Packaged within leann-core so installed wheels can import it reliably.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Code file extensions supported by astchunk
|
||||
CODE_EXTENSIONS = {
|
||||
".py": "python",
|
||||
".java": "java",
|
||||
".cs": "csharp",
|
||||
".ts": "typescript",
|
||||
".tsx": "typescript",
|
||||
".js": "typescript",
|
||||
".jsx": "typescript",
|
||||
}
|
||||
|
||||
|
||||
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
|
||||
"""Separate documents into code files and regular text files."""
|
||||
if code_extensions is None:
|
||||
code_extensions = CODE_EXTENSIONS
|
||||
|
||||
code_docs = []
|
||||
text_docs = []
|
||||
|
||||
for doc in documents:
|
||||
file_path = doc.metadata.get("file_path", "") or doc.metadata.get("file_name", "")
|
||||
if file_path:
|
||||
file_ext = Path(file_path).suffix.lower()
|
||||
if file_ext in code_extensions:
|
||||
doc.metadata["language"] = code_extensions[file_ext]
|
||||
doc.metadata["is_code"] = True
|
||||
code_docs.append(doc)
|
||||
else:
|
||||
doc.metadata["is_code"] = False
|
||||
text_docs.append(doc)
|
||||
else:
|
||||
doc.metadata["is_code"] = False
|
||||
text_docs.append(doc)
|
||||
|
||||
logger.info(f"Detected {len(code_docs)} code files and {len(text_docs)} text files")
|
||||
return code_docs, text_docs
|
||||
|
||||
|
||||
def get_language_from_extension(file_path: str) -> Optional[str]:
|
||||
"""Return language string from a filename/extension using CODE_EXTENSIONS."""
|
||||
ext = Path(file_path).suffix.lower()
|
||||
return CODE_EXTENSIONS.get(ext)
|
||||
|
||||
|
||||
def create_ast_chunks(
|
||||
documents,
|
||||
max_chunk_size: int = 512,
|
||||
chunk_overlap: int = 64,
|
||||
metadata_template: str = "default",
|
||||
) -> list[str]:
|
||||
"""Create AST-aware chunks from code documents using astchunk.
|
||||
|
||||
Falls back to traditional chunking if astchunk is unavailable.
|
||||
"""
|
||||
try:
|
||||
from astchunk import ASTChunkBuilder # optional dependency
|
||||
except ImportError as e:
|
||||
logger.error(f"astchunk not available: {e}")
|
||||
logger.info("Falling back to traditional chunking for code files")
|
||||
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
|
||||
|
||||
all_chunks = []
|
||||
for doc in documents:
|
||||
language = doc.metadata.get("language")
|
||||
if not language:
|
||||
logger.warning("No language detected; falling back to traditional chunking")
|
||||
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||
continue
|
||||
|
||||
try:
|
||||
configs = {
|
||||
"max_chunk_size": max_chunk_size,
|
||||
"language": language,
|
||||
"metadata_template": metadata_template,
|
||||
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
|
||||
}
|
||||
|
||||
repo_metadata = {
|
||||
"file_path": doc.metadata.get("file_path", ""),
|
||||
"file_name": doc.metadata.get("file_name", ""),
|
||||
"creation_date": doc.metadata.get("creation_date", ""),
|
||||
"last_modified_date": doc.metadata.get("last_modified_date", ""),
|
||||
}
|
||||
configs["repo_level_metadata"] = repo_metadata
|
||||
|
||||
chunk_builder = ASTChunkBuilder(**configs)
|
||||
code_content = doc.get_content()
|
||||
if not code_content or not code_content.strip():
|
||||
logger.warning("Empty code content, skipping")
|
||||
continue
|
||||
|
||||
chunks = chunk_builder.chunkify(code_content)
|
||||
for chunk in chunks:
|
||||
if hasattr(chunk, "text"):
|
||||
chunk_text = chunk.text
|
||||
elif isinstance(chunk, dict) and "text" in chunk:
|
||||
chunk_text = chunk["text"]
|
||||
elif isinstance(chunk, str):
|
||||
chunk_text = chunk
|
||||
else:
|
||||
chunk_text = str(chunk)
|
||||
|
||||
if chunk_text and chunk_text.strip():
|
||||
all_chunks.append(chunk_text.strip())
|
||||
|
||||
logger.info(
|
||||
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"AST chunking failed for {language} file: {e}")
|
||||
logger.info("Falling back to traditional chunking")
|
||||
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||
|
||||
return all_chunks
|
||||
|
||||
|
||||
def create_traditional_chunks(
|
||||
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
||||
) -> list[str]:
|
||||
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
|
||||
if chunk_size <= 0:
|
||||
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
|
||||
chunk_size = 256
|
||||
if chunk_overlap < 0:
|
||||
chunk_overlap = 0
|
||||
if chunk_overlap >= chunk_size:
|
||||
chunk_overlap = chunk_size // 2
|
||||
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separator=" ",
|
||||
paragraph_separator="\n\n",
|
||||
)
|
||||
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
try:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
if nodes:
|
||||
all_texts.extend(node.get_content() for node in nodes)
|
||||
except Exception as e:
|
||||
logger.error(f"Traditional chunking failed for document: {e}")
|
||||
content = doc.get_content()
|
||||
if content and content.strip():
|
||||
all_texts.append(content.strip())
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
def create_text_chunks(
|
||||
documents,
|
||||
chunk_size: int = 256,
|
||||
chunk_overlap: int = 128,
|
||||
use_ast_chunking: bool = False,
|
||||
ast_chunk_size: int = 512,
|
||||
ast_chunk_overlap: int = 64,
|
||||
code_file_extensions: Optional[list[str]] = None,
|
||||
ast_fallback_traditional: bool = True,
|
||||
) -> list[str]:
|
||||
"""Create text chunks from documents with optional AST support for code files."""
|
||||
if not documents:
|
||||
logger.warning("No documents provided for chunking")
|
||||
return []
|
||||
|
||||
local_code_extensions = CODE_EXTENSIONS.copy()
|
||||
if code_file_extensions:
|
||||
ext_mapping = {
|
||||
".py": "python",
|
||||
".java": "java",
|
||||
".cs": "c_sharp",
|
||||
".ts": "typescript",
|
||||
".tsx": "typescript",
|
||||
}
|
||||
for ext in code_file_extensions:
|
||||
if ext.lower() not in local_code_extensions:
|
||||
if ext.lower() in ext_mapping:
|
||||
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
|
||||
else:
|
||||
logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
|
||||
|
||||
all_chunks = []
|
||||
if use_ast_chunking:
|
||||
code_docs, text_docs = detect_code_files(documents, local_code_extensions)
|
||||
if code_docs:
|
||||
try:
|
||||
all_chunks.extend(
|
||||
create_ast_chunks(
|
||||
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"AST chunking failed: {e}")
|
||||
if ast_fallback_traditional:
|
||||
all_chunks.extend(
|
||||
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
|
||||
)
|
||||
else:
|
||||
raise
|
||||
if text_docs:
|
||||
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
|
||||
else:
|
||||
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||
|
||||
logger.info(f"Total chunks created: {len(all_chunks)}")
|
||||
return all_chunks
|
||||
@@ -1,6 +1,5 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@@ -10,6 +9,7 @@ from tqdm import tqdm
|
||||
|
||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||
from .registry import register_project_directory
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
|
||||
|
||||
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
||||
@@ -124,6 +124,24 @@ Examples:
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode (default: sentence-transformers)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Override Ollama-compatible embedding host",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible embedding services",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--force", "-f", action="store_true", help="Force rebuild existing index"
|
||||
)
|
||||
@@ -239,6 +257,11 @@ Examples:
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||
ask_parser.add_argument("index_name", help="Index name")
|
||||
ask_parser.add_argument(
|
||||
"query",
|
||||
nargs="?",
|
||||
help="Question to ask (omit for prompt or when using --interactive)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
@@ -249,7 +272,12 @@ Examples:
|
||||
ask_parser.add_argument(
|
||||
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
|
||||
)
|
||||
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
||||
ask_parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Override Ollama-compatible host (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--interactive", "-i", action="store_true", help="Interactive chat mode"
|
||||
)
|
||||
@@ -278,6 +306,18 @@ Examples:
|
||||
default=None,
|
||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible APIs (e.g., http://localhost:10000/v1)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
|
||||
# List command
|
||||
subparsers.add_parser("list", help="List all indexes")
|
||||
@@ -1216,13 +1256,8 @@ Examples:
|
||||
if use_ast:
|
||||
print("🧠 Using AST-aware chunking for code files")
|
||||
try:
|
||||
# Import enhanced chunking utilities
|
||||
# Add apps directory to path to import chunking utilities
|
||||
apps_dir = Path(__file__).parent.parent.parent.parent.parent / "apps"
|
||||
if apps_dir.exists():
|
||||
sys.path.insert(0, str(apps_dir))
|
||||
|
||||
from chunking import create_text_chunks
|
||||
# Import enhanced chunking utilities from packaged module
|
||||
from .chunking_utils import create_text_chunks
|
||||
|
||||
# Use enhanced chunking with AST support
|
||||
all_texts = create_text_chunks(
|
||||
@@ -1237,7 +1272,9 @@ Examples:
|
||||
)
|
||||
|
||||
except ImportError as e:
|
||||
print(f"⚠️ AST chunking not available ({e}), falling back to traditional chunking")
|
||||
print(
|
||||
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
|
||||
)
|
||||
use_ast = False
|
||||
|
||||
if not use_ast:
|
||||
@@ -1329,10 +1366,20 @@ Examples:
|
||||
|
||||
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||
|
||||
embedding_options: dict[str, Any] = {}
|
||||
if args.embedding_mode == "ollama":
|
||||
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||
elif args.embedding_mode == "openai":
|
||||
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||
if resolved_embedding_key:
|
||||
embedding_options["api_key"] = resolved_embedding_key
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
embedding_options=embedding_options or None,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.complexity,
|
||||
is_compact=args.compact,
|
||||
@@ -1480,11 +1527,38 @@ Examples:
|
||||
|
||||
llm_config = {"type": args.llm, "model": args.model}
|
||||
if args.llm == "ollama":
|
||||
llm_config["host"] = args.host
|
||||
llm_config["host"] = resolve_ollama_host(args.host)
|
||||
elif args.llm == "openai":
|
||||
llm_config["base_url"] = resolve_openai_base_url(args.api_base)
|
||||
resolved_api_key = resolve_openai_api_key(args.api_key)
|
||||
if resolved_api_key:
|
||||
llm_config["api_key"] = resolved_api_key
|
||||
|
||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||
|
||||
llm_kwargs: dict[str, Any] = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
def _ask_once(prompt: str) -> None:
|
||||
response = chat.ask(
|
||||
prompt,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
|
||||
initial_query = (args.query or "").strip()
|
||||
|
||||
if args.interactive:
|
||||
if initial_query:
|
||||
_ask_once(initial_query)
|
||||
|
||||
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||
print("=" * 40)
|
||||
|
||||
@@ -1497,41 +1571,14 @@ Examples:
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
user_input,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
_ask_once(user_input)
|
||||
else:
|
||||
query = input("Enter your question: ").strip()
|
||||
if query:
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
query = initial_query or input("Enter your question: ").strip()
|
||||
if not query:
|
||||
print("No question provided. Exiting.")
|
||||
return
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
_ask_once(query)
|
||||
|
||||
async def run(self, args=None):
|
||||
parser = self.create_parser()
|
||||
|
||||
@@ -7,11 +7,13 @@ Preserves all optimization parameters to ensure performance
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
|
||||
# Set up logger with proper level
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
@@ -31,6 +33,7 @@ def compute_embeddings(
|
||||
adaptive_optimization: bool = True,
|
||||
manual_tokenize: bool = False,
|
||||
max_length: int = 512,
|
||||
provider_options: Optional[dict[str, Any]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Unified embedding computation entry point
|
||||
@@ -46,6 +49,8 @@ def compute_embeddings(
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
"""
|
||||
provider_options = provider_options or {}
|
||||
|
||||
if mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(
|
||||
texts,
|
||||
@@ -57,11 +62,21 @@ def compute_embeddings(
|
||||
max_length=max_length,
|
||||
)
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(texts, model_name)
|
||||
return compute_embeddings_openai(
|
||||
texts,
|
||||
model_name,
|
||||
base_url=provider_options.get("base_url"),
|
||||
api_key=provider_options.get("api_key"),
|
||||
)
|
||||
elif mode == "mlx":
|
||||
return compute_embeddings_mlx(texts, model_name)
|
||||
elif mode == "ollama":
|
||||
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
|
||||
return compute_embeddings_ollama(
|
||||
texts,
|
||||
model_name,
|
||||
is_build=is_build,
|
||||
host=provider_options.get("host"),
|
||||
)
|
||||
elif mode == "gemini":
|
||||
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
||||
else:
|
||||
@@ -353,12 +368,15 @@ def compute_embeddings_sentence_transformers(
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
||||
def compute_embeddings_openai(
|
||||
texts: list[str],
|
||||
model_name: str,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> np.ndarray:
|
||||
# TODO: @yichuan-w add progress bar only in build mode
|
||||
"""Compute embeddings using OpenAI API"""
|
||||
try:
|
||||
import os
|
||||
|
||||
import openai
|
||||
except ImportError as e:
|
||||
raise ImportError(f"OpenAI package not installed: {e}")
|
||||
@@ -373,16 +391,18 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
||||
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
||||
)
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
resolved_base_url = resolve_openai_base_url(base_url)
|
||||
resolved_api_key = resolve_openai_api_key(api_key)
|
||||
|
||||
if not resolved_api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
# Cache OpenAI client
|
||||
cache_key = "openai_client"
|
||||
cache_key = f"openai_client::{resolved_base_url}"
|
||||
if cache_key in _model_cache:
|
||||
client = _model_cache[cache_key]
|
||||
else:
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
|
||||
_model_cache[cache_key] = client
|
||||
logger.info("OpenAI client cached")
|
||||
|
||||
@@ -507,7 +527,10 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
|
||||
|
||||
|
||||
def compute_embeddings_ollama(
|
||||
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
|
||||
texts: list[str],
|
||||
model_name: str,
|
||||
is_build: bool = False,
|
||||
host: Optional[str] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embeddings using Ollama API with simplified batch processing.
|
||||
@@ -518,7 +541,7 @@ def compute_embeddings_ollama(
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
host: Ollama host URL (default: http://localhost:11434)
|
||||
host: Ollama host URL (defaults to environment or http://localhost:11434)
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
@@ -533,17 +556,19 @@ def compute_embeddings_ollama(
|
||||
if not texts:
|
||||
raise ValueError("Cannot compute embeddings for empty text list")
|
||||
|
||||
resolved_host = resolve_ollama_host(host)
|
||||
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
|
||||
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'"
|
||||
)
|
||||
|
||||
# Check if Ollama is running
|
||||
try:
|
||||
response = requests.get(f"{host}/api/version", timeout=5)
|
||||
response = requests.get(f"{resolved_host}/api/version", timeout=5)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.ConnectionError:
|
||||
error_msg = (
|
||||
f"❌ Could not connect to Ollama at {host}.\n\n"
|
||||
f"❌ Could not connect to Ollama at {resolved_host}.\n\n"
|
||||
"Please ensure Ollama is running:\n"
|
||||
" • macOS/Linux: ollama serve\n"
|
||||
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
||||
@@ -555,7 +580,7 @@ def compute_embeddings_ollama(
|
||||
|
||||
# Check if model exists and provide helpful suggestions
|
||||
try:
|
||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
||||
response = requests.get(f"{resolved_host}/api/tags", timeout=5)
|
||||
response.raise_for_status()
|
||||
models = response.json()
|
||||
model_names = [model["name"] for model in models.get("models", [])]
|
||||
@@ -618,7 +643,9 @@ def compute_embeddings_ollama(
|
||||
# Verify the model supports embeddings by testing it
|
||||
try:
|
||||
test_response = requests.post(
|
||||
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
|
||||
f"{resolved_host}/api/embeddings",
|
||||
json={"model": model_name, "prompt": "test"},
|
||||
timeout=10,
|
||||
)
|
||||
if test_response.status_code != 200:
|
||||
error_msg = (
|
||||
@@ -665,7 +692,7 @@ def compute_embeddings_ollama(
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{host}/api/embeddings",
|
||||
f"{resolved_host}/api/embeddings",
|
||||
json={"model": model_name, "prompt": truncated_text},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,8 @@ import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .settings import encode_provider_options
|
||||
|
||||
# Lightweight, self-contained server manager with no cross-process inspection
|
||||
|
||||
# Set up logging based on environment variable
|
||||
@@ -82,16 +84,40 @@ class EmbeddingServerManager:
|
||||
) -> tuple[bool, int]:
|
||||
"""Start the embedding server."""
|
||||
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
||||
provider_options = kwargs.pop("provider_options", None)
|
||||
|
||||
config_signature = {
|
||||
"model_name": model_name,
|
||||
"passages_file": kwargs.get("passages_file", ""),
|
||||
"embedding_mode": embedding_mode,
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
|
||||
# If this manager already has a live server, just reuse it
|
||||
if self.server_process and self.server_process.poll() is None and self.server_port:
|
||||
if (
|
||||
self.server_process
|
||||
and self.server_process.poll() is None
|
||||
and self.server_port
|
||||
and self._server_config == config_signature
|
||||
):
|
||||
logger.info("Reusing in-process server")
|
||||
return True, self.server_port
|
||||
|
||||
# Configuration changed, stop existing server before starting a new one
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
logger.info("Existing server configuration differs; restarting embedding server")
|
||||
self.stop_server()
|
||||
|
||||
# For Colab environment, use a different strategy
|
||||
if _is_colab_environment():
|
||||
logger.info("Detected Colab environment, using alternative startup strategy")
|
||||
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
||||
return self._start_server_colab(
|
||||
port,
|
||||
model_name,
|
||||
embedding_mode,
|
||||
provider_options=provider_options,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Always pick a fresh available port
|
||||
try:
|
||||
@@ -101,13 +127,21 @@ class EmbeddingServerManager:
|
||||
return False, port
|
||||
|
||||
# Start a new server
|
||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||
return self._start_new_server(
|
||||
actual_port,
|
||||
model_name,
|
||||
embedding_mode,
|
||||
provider_options=provider_options,
|
||||
config_signature=config_signature,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _start_server_colab(
|
||||
self,
|
||||
port: int,
|
||||
model_name: str,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
provider_options: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> tuple[bool, int]:
|
||||
"""Start server with Colab-specific configuration."""
|
||||
@@ -125,8 +159,20 @@ class EmbeddingServerManager:
|
||||
|
||||
try:
|
||||
# In Colab, we'll use a more direct approach
|
||||
self._launch_server_process_colab(command, actual_port)
|
||||
return self._wait_for_server_ready_colab(actual_port)
|
||||
self._launch_server_process_colab(
|
||||
command,
|
||||
actual_port,
|
||||
provider_options=provider_options,
|
||||
)
|
||||
started, ready_port = self._wait_for_server_ready_colab(actual_port)
|
||||
if started:
|
||||
self._server_config = {
|
||||
"model_name": model_name,
|
||||
"passages_file": kwargs.get("passages_file", ""),
|
||||
"embedding_mode": embedding_mode,
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
return started, ready_port
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start embedding server in Colab: {e}")
|
||||
return False, actual_port
|
||||
@@ -134,7 +180,13 @@ class EmbeddingServerManager:
|
||||
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
|
||||
|
||||
def _start_new_server(
|
||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||
self,
|
||||
port: int,
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
provider_options: Optional[dict] = None,
|
||||
config_signature: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> tuple[bool, int]:
|
||||
"""Start a new embedding server on the given port."""
|
||||
logger.info(f"Starting embedding server on port {port}...")
|
||||
@@ -142,8 +194,20 @@ class EmbeddingServerManager:
|
||||
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
||||
|
||||
try:
|
||||
self._launch_server_process(command, port)
|
||||
return self._wait_for_server_ready(port)
|
||||
self._launch_server_process(
|
||||
command,
|
||||
port,
|
||||
provider_options=provider_options,
|
||||
)
|
||||
started, ready_port = self._wait_for_server_ready(port)
|
||||
if started:
|
||||
self._server_config = config_signature or {
|
||||
"model_name": model_name,
|
||||
"passages_file": kwargs.get("passages_file", ""),
|
||||
"embedding_mode": embedding_mode,
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
return started, ready_port
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start embedding server: {e}")
|
||||
return False, port
|
||||
@@ -173,7 +237,12 @@ class EmbeddingServerManager:
|
||||
|
||||
return command
|
||||
|
||||
def _launch_server_process(self, command: list, port: int) -> None:
|
||||
def _launch_server_process(
|
||||
self,
|
||||
command: list,
|
||||
port: int,
|
||||
provider_options: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Launch the server process."""
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
logger.info(f"Command: {' '.join(command)}")
|
||||
@@ -193,14 +262,20 @@ class EmbeddingServerManager:
|
||||
|
||||
# Start embedding server subprocess
|
||||
logger.info(f"Starting server process with command: {' '.join(command)}")
|
||||
env = os.environ.copy()
|
||||
encoded_options = encode_provider_options(provider_options)
|
||||
if encoded_options:
|
||||
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=stdout_target,
|
||||
stderr=stderr_target,
|
||||
env=env,
|
||||
)
|
||||
self.server_port = port
|
||||
# Record config for in-process reuse
|
||||
# Record config for in-process reuse (best effort; refined later when ready)
|
||||
try:
|
||||
self._server_config = {
|
||||
"model_name": command[command.index("--model-name") + 1]
|
||||
@@ -212,12 +287,14 @@ class EmbeddingServerManager:
|
||||
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||
if "--embedding-mode" in command
|
||||
else "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
except Exception:
|
||||
self._server_config = {
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
@@ -322,16 +399,27 @@ class EmbeddingServerManager:
|
||||
# Removed: cross-process adoption no longer supported
|
||||
return
|
||||
|
||||
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
||||
def _launch_server_process_colab(
|
||||
self,
|
||||
command: list,
|
||||
port: int,
|
||||
provider_options: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Launch the server process with Colab-specific settings."""
|
||||
logger.info(f"Colab Command: {' '.join(command)}")
|
||||
|
||||
# In Colab, we need to be more careful about process management
|
||||
env = os.environ.copy()
|
||||
encoded_options = encode_provider_options(provider_options)
|
||||
if encoded_options:
|
||||
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
env=env,
|
||||
)
|
||||
self.server_port = port
|
||||
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
||||
@@ -345,6 +433,7 @@ class EmbeddingServerManager:
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
|
||||
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||
|
||||
@@ -41,6 +41,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
|
||||
|
||||
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
self.embedding_options = self.meta.get("embedding_options", {})
|
||||
|
||||
self.embedding_server_manager = EmbeddingServerManager(
|
||||
backend_module_name=backend_module_name,
|
||||
@@ -77,6 +78,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
passages_file=passages_source_file,
|
||||
distance_metric=distance_metric,
|
||||
enable_warmup=kwargs.get("enable_warmup", False),
|
||||
provider_options=self.embedding_options,
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
|
||||
@@ -125,7 +127,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
from .embedding_compute import compute_embeddings
|
||||
|
||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
return compute_embeddings([query], self.embedding_model, embedding_mode)
|
||||
return compute_embeddings(
|
||||
[query],
|
||||
self.embedding_model,
|
||||
embedding_mode,
|
||||
provider_options=self.embedding_options,
|
||||
)
|
||||
|
||||
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
||||
"""Compute embeddings using the ZMQ embedding server."""
|
||||
|
||||
74
packages/leann-core/src/leann/settings.py
Normal file
74
packages/leann-core/src/leann/settings.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Runtime configuration helpers for LEANN."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
# Default fallbacks to preserve current behaviour while keeping them in one place.
|
||||
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
|
||||
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
|
||||
|
||||
|
||||
def _clean_url(value: str) -> str:
|
||||
"""Normalize URL strings by stripping trailing slashes."""
|
||||
|
||||
return value.rstrip("/") if value else value
|
||||
|
||||
|
||||
def resolve_ollama_host(explicit: str | None = None) -> str:
|
||||
"""Resolve the Ollama-compatible endpoint to use."""
|
||||
|
||||
candidates = (
|
||||
explicit,
|
||||
os.getenv("LEANN_LOCAL_LLM_HOST"),
|
||||
os.getenv("LEANN_OLLAMA_HOST"),
|
||||
os.getenv("OLLAMA_HOST"),
|
||||
os.getenv("LOCAL_LLM_ENDPOINT"),
|
||||
)
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate:
|
||||
return _clean_url(candidate)
|
||||
|
||||
return _clean_url(_DEFAULT_OLLAMA_HOST)
|
||||
|
||||
|
||||
def resolve_openai_base_url(explicit: str | None = None) -> str:
|
||||
"""Resolve the base URL for OpenAI-compatible services."""
|
||||
|
||||
candidates = (
|
||||
explicit,
|
||||
os.getenv("LEANN_OPENAI_BASE_URL"),
|
||||
os.getenv("OPENAI_BASE_URL"),
|
||||
os.getenv("LOCAL_OPENAI_BASE_URL"),
|
||||
)
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate:
|
||||
return _clean_url(candidate)
|
||||
|
||||
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
|
||||
|
||||
|
||||
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
|
||||
"""Resolve the API key for OpenAI-compatible services."""
|
||||
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
return os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
|
||||
"""Serialize provider options for child processes."""
|
||||
|
||||
if not options:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.dumps(options)
|
||||
except (TypeError, ValueError):
|
||||
# Fall back to empty payload if serialization fails
|
||||
return None
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann"
|
||||
version = "0.3.3"
|
||||
version = "0.3.4"
|
||||
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
|
||||
Reference in New Issue
Block a user