Compare commits

...

8 Commits

Author    SHA1        Message                                                               Date
Andy Lee  ec5e9ac33b  feat: chat on mps                                                     2025-07-12 06:07:43 +00:00
Andy Lee  d288946173  Merge remote-tracking branch 'origin/main' into datastore-reproduce  2025-07-12 05:42:16 +00:00
Andy Lee  0da08fbe38  refactor: chat and base searcher                                      2025-07-11 16:34:12 +00:00
Andy Lee  8bffb1e5b8  feat: reproducible research datas, rpj_wiki & dpr                     2025-07-11 02:58:04 +00:00
Andy Lee  16705fc44a  refactor: passage structure                                           2025-07-06 21:48:38 +00:00
Andy Lee  5611f708e9  docs: embedding pruning                                               2025-07-06 19:50:01 +00:00
Andy Lee  b4ae57b2c0  feat: auto discovery of packages and fix passage gen for diskann      2025-07-06 05:05:49 +00:00
Andy Lee  5659174635  fix: diskann zmq port and passages                                    2025-07-06 04:14:15 +00:00
22 changed files with 5070 additions and 3681 deletions

View File

@@ -74,7 +74,7 @@ def main():
     print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
     print(">>> Basic search results <<<")
     for i, res in enumerate(results, 1):
-        print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
+        print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")

     # --- 3. Recompute search demo ---
     print(f"\n[PHASE 3] Recompute search using embedding server...")

@@ -107,7 +107,7 @@ def main():
     print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
     print(">>> Recompute search results <<<")
     for i, res in enumerate(recompute_results, 1):
-        print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
+        print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")

     # Compare results
     print(f"\n--- Result comparison ---")

@@ -116,8 +116,8 @@ def main():
     print("\nBasic search vs Recompute results:")
     for i in range(min(len(results), len(recompute_results))):
-        basic_score = results[i]['score']
-        recompute_score = recompute_results[i]['score']
+        basic_score = results[i].score
+        recompute_score = recompute_results[i].score
         score_diff = abs(basic_score - recompute_score)
         print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
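Note: the hunks above switch result access from dict lookups (res['id']) to attribute access (res.id). A minimal sketch of the result shape this implies follows; the class name and field types are assumptions for illustration, not the actual leann class.

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class SearchResult:
    """Assumed shape of a searcher result after this refactor (sketch only)."""
    id: str                     # string passage ID returned by the index
    score: float                # similarity score (e.g. inner product)
    text: str                   # passage text resolved from storage
    metadata: Dict[str, Any] = field(default_factory=dict)

# Attribute access is now required; dict-style indexing would raise TypeError.
res = SearchResult(id="doc_42", score=0.8731, text="example passage")
print(f"ID: {res.id}, Score: {res.score:.4f}")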

View File

@@ -1,6 +1,7 @@
 import faulthandler
 faulthandler.enable()
+import argparse
 from llama_index.core import SimpleDirectoryReader, Settings
 from llama_index.core.readers.base import BaseReader
 from llama_index.node_parser.docling import DoclingNodeParser

@@ -69,17 +70,30 @@ if not INDEX_DIR.exists():
 else:
     print(f"--- Using existing index at {INDEX_DIR} ---")

-async def main():
+async def main(args):
     print(f"\n[PHASE 2] Starting Leann chat session...")
-    chat = LeannChat(index_path=INDEX_PATH)
+    llm_config = {
+        "type": args.llm,
+        "model": args.model,
+        "host": args.host
+    }
+    chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)

     query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
     query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
     query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
     print(f"You: {query}")
-    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1)
+    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32)
     print(f"Leann: {chat_response}")

 if __name__ == "__main__":
-    asyncio.run(main())
+    parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
+    parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf", "openai"], help="The LLM backend to use.")
+    parser.add_argument("--model", type=str, default='meta-llama/Llama-3.2-3B-Instruct', help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).")
+    parser.add_argument("--host", type=str, default="http://localhost:11434", help="The host for the Ollama API.")
+    args = parser.parse_args()
+    asyncio.run(main(args))
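For reference, the llm_config dictionary built in main(args) always carries the keys "type", "model", and "host". A hedged sketch of the configurations the CLI choices imply follows; the model names are taken from the argparse help text, and whether every backend actually reads "host" is an assumption.

# Sketch only: dicts mirroring what main(args) would build for each --llm choice.
ollama_config = {"type": "ollama", "model": "llama3:8b", "host": "http://localhost:11434"}
hf_config     = {"type": "hf", "model": "meta-llama/Llama-3.2-3B-Instruct", "host": "http://localhost:11434"}
openai_config = {"type": "openai", "model": "gpt-4o", "host": "http://localhost:11434"}

# e.g. chat = LeannChat(index_path=INDEX_PATH, llm_config=ollama_config)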

View File

@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Multi-Vector Aggregator for Fat Embeddings
==========================================
This module implements aggregation strategies for multi-vector embeddings,
similar to ColPali's approach where multiple patch vectors represent a single document.
Key features:
- MaxSim aggregation (take maximum similarity across patches)
- Voting-based aggregation (count patch matches)
- Weighted aggregation (attention-score weighted)
- Spatial clustering of matching patches
- Document-level result consolidation
"""
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import json
@dataclass
class PatchResult:
"""Represents a single patch search result."""
patch_id: int
image_name: str
image_path: str
coordinates: Tuple[int, int, int, int] # (x1, y1, x2, y2)
score: float
attention_score: float
scale: float
metadata: Dict[str, Any]
@dataclass
class AggregatedResult:
"""Represents an aggregated document-level result."""
image_name: str
image_path: str
doc_score: float
patch_count: int
best_patch: PatchResult
all_patches: List[PatchResult]
aggregation_method: str
spatial_clusters: Optional[List[List[PatchResult]]] = None
class MultiVectorAggregator:
"""
Aggregates multiple patch-level results into document-level results.
"""
def __init__(self,
aggregation_method: str = "maxsim",
spatial_clustering: bool = True,
cluster_distance_threshold: float = 100.0):
"""
Initialize the aggregator.
Args:
aggregation_method: "maxsim", "voting", "weighted", or "mean"
spatial_clustering: Whether to cluster spatially close patches
cluster_distance_threshold: Distance threshold for spatial clustering
"""
self.aggregation_method = aggregation_method
self.spatial_clustering = spatial_clustering
self.cluster_distance_threshold = cluster_distance_threshold
def aggregate_results(self,
search_results: List[Dict[str, Any]],
top_k: int = 10) -> List[AggregatedResult]:
"""
Aggregate patch-level search results into document-level results.
Args:
search_results: List of search results from LeannSearcher
top_k: Number of top documents to return
Returns:
List of aggregated document results
"""
# Group results by image
image_groups = defaultdict(list)
for result in search_results:
metadata = result.metadata
if "image_name" in metadata and "patch_id" in metadata:
patch_result = PatchResult(
patch_id=metadata["patch_id"],
image_name=metadata["image_name"],
image_path=metadata["image_path"],
coordinates=tuple(metadata["coordinates"]),
score=result.score,
attention_score=metadata.get("attention_score", 0.0),
scale=metadata.get("scale", 1.0),
metadata=metadata
)
image_groups[metadata["image_name"]].append(patch_result)
# Aggregate each image group
aggregated_results = []
for image_name, patches in image_groups.items():
if len(patches) == 0:
continue
agg_result = self._aggregate_image_patches(image_name, patches)
aggregated_results.append(agg_result)
# Sort by aggregated score and return top-k
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
return aggregated_results[:top_k]
def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult:
"""Aggregate patches for a single image."""
if self.aggregation_method == "maxsim":
doc_score = max(patch.score for patch in patches)
best_patch = max(patches, key=lambda p: p.score)
elif self.aggregation_method == "voting":
# Count patches above threshold
threshold = np.percentile([p.score for p in patches], 75)
doc_score = sum(1 for patch in patches if patch.score >= threshold)
best_patch = max(patches, key=lambda p: p.score)
elif self.aggregation_method == "weighted":
# Weight by attention scores
total_weighted_score = sum(p.score * p.attention_score for p in patches)
total_weights = sum(p.attention_score for p in patches)
doc_score = total_weighted_score / max(total_weights, 1e-8)
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
elif self.aggregation_method == "mean":
doc_score = np.mean([patch.score for patch in patches])
best_patch = max(patches, key=lambda p: p.score)
else:
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
# Spatial clustering if enabled
spatial_clusters = None
if self.spatial_clustering:
spatial_clusters = self._cluster_patches_spatially(patches)
return AggregatedResult(
image_name=image_name,
image_path=patches[0].image_path,
doc_score=float(doc_score),
patch_count=len(patches),
best_patch=best_patch,
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
aggregation_method=self.aggregation_method,
spatial_clusters=spatial_clusters
)
def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]:
"""Cluster patches that are spatially close to each other."""
if len(patches) <= 1:
return [patches]
clusters = []
remaining_patches = patches.copy()
while remaining_patches:
# Start new cluster with highest scoring remaining patch
seed_patch = max(remaining_patches, key=lambda p: p.score)
current_cluster = [seed_patch]
remaining_patches.remove(seed_patch)
# Add nearby patches to cluster
added_to_cluster = True
while added_to_cluster:
added_to_cluster = False
for patch in remaining_patches.copy():
if self._is_patch_nearby(patch, current_cluster):
current_cluster.append(patch)
remaining_patches.remove(patch)
added_to_cluster = True
clusters.append(current_cluster)
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool:
"""Check if a patch is spatially close to any patch in the cluster."""
patch_center = self._get_patch_center(patch.coordinates)
for cluster_patch in cluster:
cluster_center = self._get_patch_center(cluster_patch.coordinates)
distance = np.sqrt((patch_center[0] - cluster_center[0])**2 +
(patch_center[1] - cluster_center[1])**2)
if distance <= self.cluster_distance_threshold:
return True
return False
def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]:
"""Get center point of a patch."""
x1, y1, x2, y2 = coordinates
return ((x1 + x2) / 2, (y1 + y2) / 2)
def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3):
"""Pretty print aggregated results."""
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
print("=" * 80)
for i, result in enumerate(results):
print(f"\n{i+1}. {result.image_name}")
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
print(f" Path: {result.image_path}")
# Show best patch
best = result.best_patch
print(f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})")
# Show top patches
print(f" 📍 Top Patches:")
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
print(f" {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}")
# Show spatial clusters if available
if result.spatial_clusters and len(result.spatial_clusters) > 1:
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
cluster_score = max(p.score for p in cluster)
print(f" Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})")
def demo_aggregation():
"""Demonstrate the multi-vector aggregation functionality."""
print("=== Multi-Vector Aggregation Demo ===")
# Simulate some patch-level search results
# In real usage, these would come from LeannSearcher.search()
class MockResult:
def __init__(self, score, metadata):
self.score = score
self.metadata = metadata
# Simulate results for 2 images with multiple patches each
mock_results = [
# Image 1: cats_and_kitchen.jpg - 4 patches
MockResult(0.85, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 3,
"coordinates": [100, 50, 224, 174], # Kitchen area
"attention_score": 0.92,
"scale": 1.0
}),
MockResult(0.78, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 7,
"coordinates": [200, 300, 324, 424], # Cat area
"attention_score": 0.88,
"scale": 1.0
}),
MockResult(0.72, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 12,
"coordinates": [150, 100, 274, 224], # Appliances
"attention_score": 0.75,
"scale": 1.0
}),
MockResult(0.65, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 15,
"coordinates": [50, 250, 174, 374], # Furniture
"attention_score": 0.70,
"scale": 1.0
}),
# Image 2: city_street.jpg - 3 patches
MockResult(0.68, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 2,
"coordinates": [300, 100, 424, 224], # Buildings
"attention_score": 0.80,
"scale": 1.0
}),
MockResult(0.62, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 8,
"coordinates": [100, 350, 224, 474], # Street level
"attention_score": 0.75,
"scale": 1.0
}),
MockResult(0.55, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 11,
"coordinates": [400, 200, 524, 324], # Sky area
"attention_score": 0.60,
"scale": 1.0
}),
]
# Test different aggregation methods
methods = ["maxsim", "voting", "weighted", "mean"]
for method in methods:
print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}")
aggregator = MultiVectorAggregator(
aggregation_method=method,
spatial_clustering=True,
cluster_distance_threshold=100.0
)
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
aggregator.print_aggregated_results(aggregated)
if __name__ == "__main__":
demo_aggregation()
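As a small worked example of how the aggregation methods in _aggregate_image_patches differ, take the four patch scores of the first mock image above (attention scores in the same order). The values below are computed from that data; the voting method instead counts patches above the 75th-percentile score threshold.

# Worked example: patch scores and attention scores for cats_and_kitchen.jpg.
scores = [0.85, 0.78, 0.72, 0.65]
attn   = [0.92, 0.88, 0.75, 0.70]

maxsim   = max(scores)                                              # 0.85
mean     = sum(scores) / len(scores)                                # 0.75
weighted = sum(s * a for s, a in zip(scores, attn)) / sum(attn)     # ≈ 0.758
print(maxsim, round(mean, 4), round(weighted, 4))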

examples/run_evaluation.py (new file, 157 lines)

View File

@@ -0,0 +1,157 @@
#!/usr/bin/env python3
"""
This script runs a recall evaluation on a given LEANN index.
It correctly compares results by fetching the text content for both the new search
results and the golden standard results, making the comparison robust to ID changes.
"""
import json
import argparse
import time
from pathlib import Path
import sys
import numpy as np
from typing import List, Dict, Any
import glob
import pickle
# Add project root to path to allow importing from leann
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
from leann.api import LeannSearcher
# --- Configuration ---
NQ_QUERIES_FILE = Path("/opt/dlami/nvme/scaling_out/examples/nq_open.jsonl")
# Ground truth files for different datasets
GROUND_TRUTH_FILES = {
"rpj_wiki": "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json",
"dpr": "/opt/dlami/nvme/scaling_out/indices/dpr/facebook/contriever-msmarco/flat_results_nq_k3.json"
}
# Old passages for different datasets
OLD_PASSAGES_GLOBS = {
"rpj_wiki": "/opt/dlami/nvme/scaling_out/passages/rpj_wiki/8-shards/raw_passages-*-of-8.pkl.jsonl",
"dpr": "/opt/dlami/nvme/scaling_out/passages/dpr/1-shards/raw_passages-*-of-1.pkl.jsonl"
}
# --- Helper Class to Load Original Passages ---
class OldPassageLoader:
"""A simplified version of the old LazyPassages class to fetch golden results by ID."""
def __init__(self, passages_glob: str):
self.jsonl_paths = sorted(glob.glob(passages_glob))
self.offsets = {}
self.fps = [open(p, "r", encoding="utf-8") for p in self.jsonl_paths]
print("Building offset map for original passages...")
for i, shard_path_str in enumerate(self.jsonl_paths):
old_idx_path = Path(shard_path_str.replace(".jsonl", ".idx"))
if not old_idx_path.exists(): continue
with open(old_idx_path, 'rb') as f:
shard_offsets = pickle.load(f)
for pid, offset in shard_offsets.items():
self.offsets[str(pid)] = (i, offset)
print("Offset map for original passages is ready.")
def get_passage_by_id(self, pid: str) -> Dict[str, Any]:
pid = str(pid)
if pid not in self.offsets:
raise ValueError(f"Passage ID {pid} not found in offsets")
file_idx, offset = self.offsets[pid]
fp = self.fps[file_idx]
fp.seek(offset)
return json.loads(fp.readline())
def __del__(self):
for fp in self.fps:
fp.close()
def load_queries(file_path: Path) -> List[str]:
queries = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
data = json.loads(line)
queries.append(data['query'])
return queries
def main():
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
parser.add_argument("index_path", type=str, help="Path to the LEANN index to evaluate.")
parser.add_argument("--num-queries", type=int, default=10, help="Number of queries to evaluate.")
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
args = parser.parse_args()
print(f"--- Final, Correct Recall Evaluation (efSearch={args.ef_search}) ---")
# Detect dataset type from index path
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
print("WARNING: Unknown dataset type, defaulting to rpj_wiki")
dataset_type = "rpj_wiki"
print(f"INFO: Detected dataset type: {dataset_type}")
try:
searcher = LeannSearcher(args.index_path)
queries = load_queries(NQ_QUERIES_FILE)
golden_results_file = GROUND_TRUTH_FILES[dataset_type]
old_passages_glob = OLD_PASSAGES_GLOBS[dataset_type]
print(f"INFO: Using ground truth file: {golden_results_file}")
print(f"INFO: Using old passages glob: {old_passages_glob}")
with open(golden_results_file, 'r') as f:
golden_results_data = json.load(f)
old_passage_loader = OldPassageLoader(old_passages_glob)
num_eval_queries = min(args.num_queries, len(queries))
queries = queries[:num_eval_queries]
print(f"\nRunning evaluation on {num_eval_queries} queries...")
recall_scores = []
search_times = []
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
search_times.append(time.time() - start_time)
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}
golden_ids = golden_results_data["indices"][i][:args.top_k]
golden_texts = {old_passage_loader.get_passage_by_id(str(gid))['text'] for gid in golden_ids}
overlap = len(new_texts & golden_texts)
recall = overlap / len(golden_texts) if golden_texts else 0
recall_scores.append(recall)
print("\n--- EVALUATION RESULTS ---")
print(f"Query: {queries[i]}")
print(f"New Results: {new_texts}")
print(f"Golden Results: {golden_texts}")
print(f"Overlap: {overlap}")
print(f"Recall: {recall}")
print(f"Search Time: {search_times[-1]:.4f}s")
print(f"--------------------------------")
avg_recall = np.mean(recall_scores) if recall_scores else 0
avg_time = np.mean(search_times) if search_times else 0
print(f"\n🎉 --- Evaluation Complete ---")
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
print(f"Avg. Search Time: {avg_time:.4f}s")
except Exception as e:
print(f"\n❌ An error occurred during evaluation: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()
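Since the script matches on passage text rather than IDs, recall@k reduces to set overlap between the new result texts and the golden texts. A minimal worked example of that calculation, with made-up texts:

# Toy recall@3 computation mirroring the loop above (texts are illustrative, not real data).
new_texts    = {"passage about paris", "passage about berlin", "passage about rome"}
golden_texts = {"passage about paris", "passage about madrid", "passage about rome"}

overlap = len(new_texts & golden_texts)      # 2
recall  = overlap / len(golden_texts)        # 2 / 3 ≈ 0.667
print(f"Recall@3: {recall:.4f}")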

View File

@@ -3,29 +3,25 @@ import os
 import json
 import struct
 from pathlib import Path
-from typing import Dict, Any
+from typing import Dict, Any, List
 import contextlib
-import threading
-import time
-import atexit
-import socket
-import subprocess
-import sys
+import pickle

-from leann.embedding_server_manager import EmbeddingServerManager
+from leann.searcher_base import BaseSearcher
 from leann.registry import register_backend
 from leann.interface import (
     LeannBackendFactoryInterface,
     LeannBackendBuilderInterface,
     LeannBackendSearcherInterface
 )
-from . import _diskannpy as diskannpy

-METRIC_MAP = {
+def _get_diskann_metrics():
+    from . import _diskannpy as diskannpy
+    return {
         "mips": diskannpy.Metric.INNER_PRODUCT,
         "l2": diskannpy.Metric.L2,
         "cosine": diskannpy.Metric.COSINE,
     }

 @contextlib.contextmanager
 def chdir(path):

@@ -51,210 +47,87 @@ class DiskannBackend(LeannBackendFactoryInterface):
     @staticmethod
     def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
-        path = Path(index_path)
-        meta_path = path.parent / f"{path.name}.meta.json"
-        if not meta_path.exists():
-            raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
-        with open(meta_path, 'r') as f:
-            meta = json.load(f)
-        # Pass essential metadata to the searcher
-        kwargs['meta'] = meta
         return DiskannSearcher(index_path, **kwargs)
class DiskannBuilder(LeannBackendBuilderInterface): class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.build_params = kwargs self.build_params = kwargs
def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs): def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
"""Generate passages file for recompute mode, mirroring HNSW backend."""
try:
chunks = kwargs.get('chunks', [])
if not chunks:
print("INFO: No chunks data provided, skipping passages file generation for DiskANN.")
return
passages_data = {str(node_id): chunk["text"] for node_id, chunk in enumerate(chunks)}
passages_file = index_dir / f"{index_prefix}.passages.json"
with open(passages_file, 'w', encoding='utf-8') as f:
json.dump(passages_data, f, ensure_ascii=False, indent=2)
print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
except Exception as e:
print(f"💥 ERROR: Failed to generate passages file for DiskANN. Exception: {e}")
pass
def build(self, data: np.ndarray, index_path: str, **kwargs):
path = Path(index_path) path = Path(index_path)
index_dir = path.parent index_dir = path.parent
index_prefix = path.stem index_prefix = path.stem
index_dir.mkdir(parents=True, exist_ok=True) index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32: if data.dtype != np.float32:
data = data.astype(np.float32) data = data.astype(np.float32)
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
data_filename = f"{index_prefix}_data.bin" data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename) _write_vectors_to_bin(data, index_dir / data_filename)
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f:
pickle.dump(label_map, f)
build_kwargs = {**self.build_params, **kwargs} build_kwargs = {**self.build_params, **kwargs}
metric_str = build_kwargs.get("distance_metric", "mips").lower() metric_enum = _get_diskann_metrics().get(build_kwargs.get("distance_metric", "mips").lower())
metric_enum = METRIC_MAP.get(metric_str)
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.") raise ValueError(f"Unsupported distance_metric.")
complexity = build_kwargs.get("complexity", 64)
graph_degree = build_kwargs.get("graph_degree", 32)
final_index_ram_limit = build_kwargs.get("search_memory_maximum", 4.0)
indexing_ram_budget = build_kwargs.get("build_memory_maximum", 8.0)
num_threads = build_kwargs.get("num_threads", 8)
pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
codebook_prefix = ""
is_recompute = build_kwargs.get("is_recompute", False)
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
try: try:
from . import _diskannpy as diskannpy
with chdir(index_dir): with chdir(index_dir):
diskannpy.build_disk_float_index( diskannpy.build_disk_float_index(
metric_enum, metric_enum, data_filename, index_prefix,
data_filename, build_kwargs.get("complexity", 64), build_kwargs.get("graph_degree", 32),
index_prefix, build_kwargs.get("search_memory_maximum", 4.0), build_kwargs.get("build_memory_maximum", 8.0),
complexity, build_kwargs.get("num_threads", 8), build_kwargs.get("pq_disk_bytes", 0), ""
graph_degree,
final_index_ram_limit,
indexing_ram_budget,
num_threads,
pq_disk_bytes,
codebook_prefix
) )
print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
if is_recompute:
self._generate_passages_file(index_dir, index_prefix, **build_kwargs)
except Exception as e:
print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
raise
finally: finally:
temp_data_file = index_dir / data_filename temp_data_file = index_dir / data_filename
if temp_data_file.exists(): if temp_data_file.exists():
os.remove(temp_data_file) os.remove(temp_data_file)
class DiskannSearcher(LeannBackendSearcherInterface): class DiskannSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs): def __init__(self, index_path: str, **kwargs):
self.meta = kwargs.get("meta", {}) super().__init__(index_path, backend_module_name="leann_backend_diskann.embedding_server", **kwargs)
if not self.meta: from . import _diskannpy as diskannpy
raise ValueError("DiskannSearcher requires metadata from .meta.json.")
dimensions = self.meta.get("dimensions") distance_metric = kwargs.get("distance_metric", "mips").lower()
if not dimensions: metric_enum = _get_diskann_metrics().get(distance_metric)
raise ValueError("Dimensions not found in Leann metadata.")
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
metric_enum = METRIC_MAP.get(self.distance_metric)
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.") raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
self.embedding_model = self.meta.get("embedding_model") self.num_threads = kwargs.get("num_threads", 8)
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
path = Path(index_path)
self.index_dir = path.parent
self.index_prefix = path.stem
num_threads = kwargs.get("num_threads", 8)
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
self.zmq_port = kwargs.get("zmq_port", 6666) self.zmq_port = kwargs.get("zmq_port", 6666)
try: full_index_prefix = str(self.index_dir / self.index_path.stem)
full_index_prefix = str(self.index_dir / self.index_prefix)
self._index = diskannpy.StaticDiskFloatIndex( self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", "" metric_enum, full_index_prefix, self.num_threads,
kwargs.get("num_nodes_to_cache", 0), 1, self.zmq_port, "", ""
) )
self.num_threads = num_threads
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_diskann.embedding_server"
)
print("✅ DiskANN index loaded successfully.")
except Exception as e:
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
raise
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
complexity = kwargs.get("complexity", 256) recompute = kwargs.get("recompute_beighbor_embeddings", False)
beam_width = kwargs.get("beam_width", 4) if recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False) if not meta_file_path.exists():
skip_search_reorder = kwargs.get("skip_search_reorder", False) raise RuntimeError(f"FATAL: Recompute mode enabled but metadata file not found: {meta_file_path}")
recompute_beighbor_embeddings = kwargs.get("recompute_beighbor_embeddings", False) zmq_port = kwargs.get("zmq_port", self.zmq_port)
dedup_node_dis = kwargs.get("dedup_node_dis", False) self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
prune_ratio = kwargs.get("prune_ratio", 0.0)
batch_recompute = kwargs.get("batch_recompute", False)
global_pruning = kwargs.get("global_pruning", False)
port = kwargs.get("zmq_port", self.zmq_port)
if recompute_beighbor_embeddings:
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
if not self.embedding_model:
raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
passages_file = kwargs.get("passages_file")
if not passages_file:
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
if potential_passages_file.exists():
passages_file = str(potential_passages_file)
print(f"INFO: Automatically found passages file: {passages_file}")
if not passages_file:
raise RuntimeError(
f"Recompute mode is enabled, but no passages file was found. "
f"A '{self.index_prefix}.passages.json' file should exist in the index directory "
f"'{self.index_dir}'. Ensure you build the index with 'recompute=True'."
)
server_started = self.embedding_server_manager.start_server(
port=self.zmq_port,
model_name=self.embedding_model,
distance_metric=self.distance_metric,
passages_file=passages_file
)
if not server_started:
raise RuntimeError(f"Failed to start DiskANN embedding server on port {self.zmq_port}")
if query.dtype != np.float32: if query.dtype != np.float32:
query = query.astype(np.float32) query = query.astype(np.float32)
if query.ndim == 1:
query = np.expand_dims(query, axis=0)
try:
labels, distances = self._index.batch_search( labels, distances = self._index.batch_search(
query, query, query.shape[0], top_k,
query.shape[0], kwargs.get("complexity", 256), kwargs.get("beam_width", 4), self.num_threads,
top_k, kwargs.get("USE_DEFERRED_FETCH", False), kwargs.get("skip_search_reorder", False),
complexity, recompute, kwargs.get("dedup_node_dis", False), kwargs.get("prune_ratio", 0.0),
beam_width, kwargs.get("batch_recompute", False), kwargs.get("global_pruning", False)
self.num_threads,
USE_DEFERRED_FETCH,
skip_search_reorder,
recompute_beighbor_embeddings,
dedup_node_dis,
prune_ratio,
batch_recompute,
global_pruning
) )
return {"labels": labels, "distances": distances}
except Exception as e:
print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
batch_size = query.shape[0]
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
def __del__(self): string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels]
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server() return {"labels": string_labels, "distances": distances}
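Both backends now persist a label map at build time so the graph's integer labels can be translated back into string passage IDs at search time (see the label_map pickling in build() and the string_labels mapping in search()). A minimal self-contained sketch of that round trip; the file name comes from the diff, the paths and IDs below are illustrative.

import pickle
from pathlib import Path

# Build side: integer position -> string passage ID, as written next to the index.
ids = ["doc_a", "doc_b", "doc_c"]
label_map = {i: str_id for i, str_id in enumerate(ids)}
with open(Path("/tmp") / "leann.labels.map", "wb") as f:
    pickle.dump(label_map, f)

# Search side: translate raw integer labels from the ANN index back to string IDs.
with open(Path("/tmp") / "leann.labels.map", "rb") as f:
    label_map = pickle.load(f)
raw_labels = [[2, 0]]  # one query, top-2 integer labels
string_labels = [[label_map.get(l, f"unknown_{l}") for l in batch] for batch in raw_labels]
print(string_labels)  # [['doc_c', 'doc_a']]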

View File

@@ -15,6 +15,8 @@ import os
 from contextlib import contextmanager
 import zmq
 import numpy as np
+from pathlib import Path
+import pickle

 RED = "\033[91m"
 RESET = "\033[0m"
@@ -39,23 +41,113 @@ class SimplePassageLoader:
def __len__(self) -> int: def __len__(self) -> int:
return len(self.passages_data) return len(self.passages_data)
def load_passages_from_file(passages_file: str) -> SimplePassageLoader: def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
""" """
Load passages from a JSON file Load passages using metadata file with PassageManager for lazy loading
Expected format: {"passage_id": "passage_text", ...}
""" """
if not os.path.exists(passages_file): # Load metadata to get passage sources
print(f"Warning: Passages file {passages_file} not found. Using empty loader.") with open(meta_file, 'r') as f:
return SimplePassageLoader() meta = json.load(f)
# Import PassageManager dynamically to avoid circular imports
import sys
from pathlib import Path
# Find the leann package directory relative to this file
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try: try:
with open(passages_file, 'r', encoding='utf-8') as f: from leann.api import PassageManager
passages_data = json.load(f) passage_manager = PassageManager(meta['passage_sources'])
print(f"Loaded {len(passages_data)} passages from {passages_file}") finally:
return SimplePassageLoader(passages_data) sys.path.pop(0)
# Load label map
passages_dir = Path(meta_file).parent
label_map_file = passages_dir / "leann.labels.map"
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
print(f"Initialized lazy passage loading for {len(label_map)} passages")
class LazyPassageLoader(SimplePassageLoader):
def __init__(self, passage_manager, label_map):
self.passage_manager = passage_manager
self.label_map = label_map
# Initialize parent with empty data
super().__init__({})
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID with lazy loading"""
try:
int_id = int(passage_id)
if int_id in self.label_map:
string_id = self.label_map[int_id]
passage_data = self.passage_manager.get_passage(string_id)
if passage_data and passage_data.get("text"):
return {"text": passage_data["text"]}
else:
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
else:
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
except Exception as e: except Exception as e:
print(f"Error loading passages from {passages_file}: {e}") raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
return SimplePassageLoader()
def __len__(self) -> int:
return len(self.label_map)
return LazyPassageLoader(passage_manager, label_map)
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
"""
Load passages from a JSONL file with label map support
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
"""
if not os.path.exists(passages_file):
raise FileNotFoundError(f"Passages file {passages_file} not found.")
if not passages_file.endswith('.jsonl'):
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
# Load label map (int -> string_id)
passages_dir = Path(passages_file).parent
label_map_file = passages_dir / "leann.labels.map"
label_map = {}
if label_map_file.exists():
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
# Load passages by string ID
string_id_passages = {}
with open(passages_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
passage = json.loads(line)
string_id_passages[passage['id']] = passage['text']
# Create int ID -> text mapping using label map
passages_data = {}
for int_id, string_id in label_map.items():
if string_id in string_id_passages:
passages_data[str(int_id)] = string_id_passages[string_id]
else:
print(f"WARNING: String ID {string_id} from label map not found in passages")
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
return SimplePassageLoader(passages_data)
def create_embedding_server_thread( def create_embedding_server_thread(
zmq_port=5555, zmq_port=5555,
@@ -113,6 +205,19 @@ def create_embedding_server_thread(
# Load passages from file if provided # Load passages from file if provided
if passages_file and os.path.exists(passages_file): if passages_file and os.path.exists(passages_file):
# Check if it's a metadata file or a single passages file
if passages_file.endswith('.meta.json'):
passages = load_passages_from_metadata(passages_file)
else:
# Try to find metadata file in same directory
passages_dir = Path(passages_file).parent
meta_files = list(passages_dir.glob("*.meta.json"))
if meta_files:
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
passages = load_passages_from_metadata(str(meta_files[0]))
else:
# Fallback to original single file loading (will cause warnings)
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
passages = load_passages_from_file(passages_file) passages = load_passages_from_file(passages_file)
else: else:
print("WARNING: No passages file provided or file not found. Using an empty passage loader.") print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
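The embedding server now prefers lazy loading through the index metadata file over loading a single passages file. A condensed, self-contained sketch of that selection logic follows; the loader callables are passed in for illustration rather than imported, so this is a sketch of the control flow, not the server's actual functions.

from pathlib import Path

def resolve_passages(passages_file: str, load_from_metadata, load_from_file):
    """Condensed sketch of the passage-source selection shown above."""
    if passages_file.endswith('.meta.json'):
        return load_from_metadata(passages_file)
    meta_files = sorted(Path(passages_file).parent.glob("*.meta.json"))
    if meta_files:
        # Lazy loading via PassageManager + leann.labels.map
        return load_from_metadata(str(meta_files[0]))
    # Fallback: single JSONL passages file (may warn about missing passages)
    return load_from_file(passages_file)

# Example: resolve_passages("indices/demo/demo.passages.jsonl", print, print) would look for
# a *.meta.json next to the passages file before falling back to direct loading.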

View File

@@ -468,16 +468,27 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
     # --- Write CSR HNSW graph data using unified function ---
     print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...")

-    # Determine storage fourcc based on prune_embeddings
-    output_storage_fourcc = NULL_INDEX_FOURCC if prune_embeddings else (storage_fourcc if 'storage_fourcc' in locals() else NULL_INDEX_FOURCC)
+    # Determine storage fourcc and data based on prune_embeddings
     if prune_embeddings:
         print(f" Pruning embeddings: Writing NULL storage marker.")
+        output_storage_fourcc = NULL_INDEX_FOURCC
+        storage_data = b''
+    else:
+        # Keep embeddings - read and preserve original storage data
+        if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
+            print(f" Preserving embeddings: Reading original storage data...")
+            storage_data = f_in.read()  # Read remaining storage data
+            output_storage_fourcc = storage_fourcc
+            print(f" Read {len(storage_data)} bytes of storage data")
+        else:
+            print(f" No embeddings found in original file (NULL storage)")
+            output_storage_fourcc = NULL_INDEX_FOURCC
             storage_data = b''

     # Use the unified write function
     write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
                          levels_np, compact_level_ptr, compact_node_offsets_np,
-                         compact_neighbors_data, output_storage_fourcc, storage_data if not prune_embeddings else b'')
+                         compact_neighbors_data, output_storage_fourcc, storage_data)

     # Clean up memory
     del assign_probas_np, cum_nneighbor_per_level_np, levels_np
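The fix above makes the keep-embeddings path actually copy the original storage blob instead of always writing an empty one. A small self-contained sketch of that decision; the fourcc constant matches the one used elsewhere in this file, while the function name and toy file handle are illustrative.

import io

NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')

def select_storage(f_in, storage_fourcc, prune_embeddings):
    """Return (fourcc, blob) to append after the CSR graph, mirroring the logic above."""
    if prune_embeddings:
        return NULL_INDEX_FOURCC, b''            # embeddings dropped; recompute needed at search time
    if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
        return storage_fourcc, f_in.read()       # keep the original embedding storage verbatim
    return NULL_INDEX_FOURCC, b''                # source index had no embeddings to begin with

# Example: pruning always yields an empty blob regardless of the source storage.
fourcc, blob = select_storage(io.BytesIO(b'payload'), 0x1234, prune_embeddings=True)
print(hex(fourcc), len(blob))  # null fourcc, 0 bytes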

View File

@@ -1,18 +1,12 @@
 import numpy as np
 import os
 import json
-import struct
 from pathlib import Path
-from typing import Dict, Any
-import contextlib
-import threading
-import time
-import atexit
-import socket
-import subprocess
-import sys
+from typing import Dict, Any, List
+import pickle
+import shutil

-from leann.embedding_server_manager import EmbeddingServerManager
+from leann.searcher_base import BaseSearcher
 from .convert_to_csr import convert_hnsw_graph_to_csr
 from leann.registry import register_backend
@@ -38,97 +32,53 @@ class HNSWBackend(LeannBackendFactoryInterface):
     @staticmethod
     def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
-        path = Path(index_path)
-        meta_path = path.parent / f"{path.name}.meta.json"
-        if not meta_path.exists():
-            raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
-        with open(meta_path, 'r') as f:
-            meta = json.load(f)
-        kwargs['meta'] = meta
         return HNSWSearcher(index_path, **kwargs)
class HNSWBuilder(LeannBackendBuilderInterface): class HNSWBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.build_params = kwargs.copy() self.build_params = kwargs.copy()
# --- Configuration defaults with standardized names ---
self.is_compact = self.build_params.setdefault("is_compact", True) self.is_compact = self.build_params.setdefault("is_compact", True)
self.is_recompute = self.build_params.setdefault("is_recompute", True) self.is_recompute = self.build_params.setdefault("is_recompute", True)
# --- Additional Options ---
self.is_skip_neighbors = self.build_params.setdefault("is_skip_neighbors", False)
self.disk_cache_ratio = self.build_params.setdefault("disk_cache_ratio", 0.0)
self.external_storage_path = self.build_params.get("external_storage_path", None)
# --- Standard HNSW parameters ---
self.M = self.build_params.setdefault("M", 32) self.M = self.build_params.setdefault("M", 32)
self.efConstruction = self.build_params.setdefault("efConstruction", 200) self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips") self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions") self.dimensions = self.build_params.get("dimensions")
if self.is_skip_neighbors and not self.is_compact: def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
if self.is_recompute and not self.is_compact:
raise ValueError("is_recompute requires is_compact=True for efficiency")
def build(self, data: np.ndarray, index_path: str, **kwargs):
"""Build HNSW index using FAISS"""
from . import faiss from . import faiss
path = Path(index_path) path = Path(index_path)
index_dir = path.parent index_dir = path.parent
index_prefix = path.stem index_prefix = path.stem
index_dir.mkdir(parents=True, exist_ok=True) index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32: if data.dtype != np.float32:
data = data.astype(np.float32) data = data.astype(np.float32)
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
metric_str = self.distance_metric.lower() label_map = {i: str_id for i, str_id in enumerate(ids)}
metric_enum = get_metric_map().get(metric_str) label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f:
pickle.dump(label_map, f)
metric_enum = get_metric_map().get(self.distance_metric.lower())
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.") raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
M = self.M dim = self.dimensions or data.shape[1]
efConstruction = self.efConstruction index = faiss.IndexHNSWFlat(dim, self.M, metric_enum)
dim = self.dimensions index.hnsw.efConstruction = self.efConstruction
if not dim:
dim = data.shape[1]
print(f"INFO: Building HNSW index for {data.shape[0]} vectors with metric {metric_enum}...") if self.distance_metric.lower() == "cosine":
try:
index = faiss.IndexHNSWFlat(dim, M, metric_enum)
index.hnsw.efConstruction = efConstruction
if metric_str == "cosine":
faiss.normalize_L2(data) faiss.normalize_L2(data)
index.add(data.shape[0], faiss.swig_ptr(data)) index.add(data.shape[0], faiss.swig_ptr(data))
index_file = index_dir / f"{index_prefix}.index" index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file)) faiss.write_index(index, str(index_file))
print(f"✅ HNSW index built successfully at '{index_file}'")
if self.is_compact: if self.is_compact:
self._convert_to_csr(index_file) self._convert_to_csr(index_file)
if self.is_recompute:
self._generate_passages_file(index_dir, index_prefix, **kwargs)
except Exception as e:
print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
raise
def _convert_to_csr(self, index_file: Path): def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format""" """Convert built index to CSR format"""
try:
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard" mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
print(f"INFO: Converting HNSW index to {mode_str} format...") print(f"INFO: Converting HNSW index to {mode_str} format...")
@@ -142,8 +92,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if success: if success:
print("✅ CSR conversion successful.") print("✅ CSR conversion successful.")
import shutil
# rename index_file to index_file.old
index_file_old = index_file.with_suffix(".old") index_file_old = index_file.with_suffix(".old")
shutil.move(str(index_file), str(index_file_old)) shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file)) shutil.move(str(csr_temp_file), str(index_file))
@@ -154,220 +102,53 @@ class HNSWBuilder(LeannBackendBuilderInterface):
os.remove(csr_temp_file) os.remove(csr_temp_file)
raise RuntimeError("CSR conversion failed - cannot proceed with compact format") raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
except Exception as e: class HNSWSearcher(BaseSearcher):
print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
raise
def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
"""Generate passages file for recompute mode"""
try:
chunks = kwargs.get('chunks', [])
if not chunks:
print("INFO: No chunks data provided, skipping passages file generation")
return
# Generate node_id to text mapping
passages_data = {}
for node_id, chunk in enumerate(chunks):
passages_data[str(node_id)] = chunk["text"]
# Save passages file
passages_file = index_dir / f"{index_prefix}.passages.json"
with open(passages_file, 'w', encoding='utf-8') as f:
json.dump(passages_data, f, ensure_ascii=False, indent=2)
print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
except Exception as e:
print(f"💥 ERROR: Failed to generate passages file. Exception: {e}")
# Don't raise - this is not critical for index building
pass
class HNSWSearcher(LeannBackendSearcherInterface):
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
"""
Robustly determines the index's storage status by parsing the file.
Returns:
A tuple (is_compact, is_pruned).
"""
if not index_file.exists():
return False, False
with open(index_file, 'rb') as f:
try:
def read_struct(fmt):
size = struct.calcsize(fmt)
data = f.read(size)
if len(data) != size:
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'.")
return struct.unpack(fmt, data)[0]
def skip_vector(element_size):
count = read_struct('<Q')
f.seek(count * element_size, 1)
# 1. Read up to the compact flag
read_struct('<I'); read_struct('<i'); read_struct('<q');
read_struct('<q'); read_struct('<q'); read_struct('<?')
metric_type = read_struct('<i')
if metric_type > 1: read_struct('<f')
skip_vector(8); skip_vector(4); skip_vector(4)
# 2. Check if there's a compact flag byte
# Try to read the compact flag, but handle both old and new formats
pos_before_compact = f.tell()
try:
is_compact = read_struct('<?')
print(f"INFO: Detected is_compact flag as: {is_compact}")
except (EOFError, struct.error):
# Old format without compact flag - assume non-compact
f.seek(pos_before_compact)
is_compact = False
print(f"INFO: No compact flag found, assuming is_compact=False")
# 3. Read storage FourCC to determine if pruned
is_pruned = False
try:
if is_compact:
# For compact, we need to skip pointers and scalars to get to the storage FourCC
skip_vector(8) # level_ptr
skip_vector(8) # node_offsets
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
storage_fourcc = read_struct('<I')
else:
# For non-compact, we need to read the flag probe, then skip offsets and neighbors
pos_before_probe = f.tell()
flag_byte = f.read(1)
if not (flag_byte and flag_byte == b'\x00'):
f.seek(pos_before_probe)
skip_vector(8); skip_vector(4) # offsets, neighbors
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
# Now we are at the storage. The entire rest is storage blob.
storage_fourcc = struct.unpack('<I', f.read(4))[0]
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
if storage_fourcc == NULL_INDEX_FOURCC:
is_pruned = True
except (EOFError, struct.error):
# Cannot determine pruning status, assume not pruned
pass
print(f"INFO: Detected is_pruned as: {is_pruned}")
return is_compact, is_pruned
except (EOFError, struct.error) as e:
print(f"WARNING: Could not parse index file to detect format: {e}. Assuming standard, not pruned.")
return False, False
def __init__(self, index_path: str, **kwargs): def __init__(self, index_path: str, **kwargs):
super().__init__(index_path, backend_module_name="leann_backend_hnsw.hnsw_embedding_server", **kwargs)
from . import faiss from . import faiss
self.meta = kwargs.get("meta", {})
if not self.meta:
raise ValueError("HNSWSearcher requires metadata from .meta.json.")
self.dimensions = self.meta.get("dimensions")
if not self.dimensions:
raise ValueError("Dimensions not found in Leann metadata.")
self.distance_metric = self.meta.get("distance_metric", "mips").lower() self.distance_metric = self.meta.get("distance_metric", "mips").lower()
metric_enum = get_metric_map().get(self.distance_metric) metric_enum = get_metric_map().get(self.distance_metric)
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.") raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
self.embedding_model = self.meta.get("embedding_model") self.is_compact, self.is_pruned = (
if not self.embedding_model: self.meta.get('is_compact', True),
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.") self.meta.get('is_pruned', True)
)
path = Path(index_path) index_file = self.index_dir / f"{self.index_path.stem}.index"
self.index_dir = path.parent
self.index_prefix = path.stem
index_file = self.index_dir / f"{self.index_prefix}.index"
if not index_file.exists(): if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}") raise FileNotFoundError(f"HNSW index file not found at {index_file}")
self.is_compact, self.is_pruned = self._get_index_storage_status(index_file)
# Validate configuration constraints
if not self.is_compact and kwargs.get("is_skip_neighbors", False):
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
if kwargs.get("is_recompute", False) and kwargs.get("external_storage_path"):
raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously")
hnsw_config = faiss.HNSWIndexConfig() hnsw_config = faiss.HNSWIndexConfig()
hnsw_config.is_compact = self.is_compact hnsw_config.is_compact = self.is_compact
# Apply additional configuration options with strict validation
hnsw_config.is_skip_neighbors = kwargs.get("is_skip_neighbors", False)
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False) hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
hnsw_config.external_storage_path = kwargs.get("external_storage_path")
self.zmq_port = kwargs.get("zmq_port", 5557)
if self.is_pruned and not hnsw_config.is_recompute: if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.") raise RuntimeError("Index is pruned but recompute is disabled.")
print(f"INFO: Loading index with is_compact={self.is_compact}, is_pruned={self.is_pruned}")
print(f"INFO: Config - skip_neighbors={hnsw_config.is_skip_neighbors}, recompute={hnsw_config.is_recompute}")
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config) self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
if self.is_compact:
print("✅ Compact CSR format HNSW index loaded successfully.")
else:
print("✅ Standard HNSW index loaded successfully.")
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
"""Search using HNSW index with optional recompute functionality"""
from . import faiss from . import faiss
ef = kwargs.get("complexity", 200)
if self.is_pruned: if self.is_pruned:
print(f"INFO: Index is pruned - ensuring embedding server is running for recompute.") meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not self.embedding_model: if not meta_file_path.exists():
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.") raise RuntimeError(f"FATAL: Index is pruned but metadata file not found: {meta_file_path}")
passages_file = kwargs.get("passages_file")
if not passages_file:
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
if potential_passages_file.exists():
passages_file = str(potential_passages_file)
print(f"INFO: Automatically found passages file: {passages_file}")
else:
raise RuntimeError(f"FATAL: Index is pruned but no passages file found.")
zmq_port = kwargs.get("zmq_port", 5557) zmq_port = kwargs.get("zmq_port", 5557)
server_started = self.embedding_server_manager.start_server( self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
port=zmq_port,
model_name=self.embedding_model,
passages_file=passages_file,
distance_metric=self.distance_metric
)
if not server_started:
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
if query.dtype != np.float32: if query.dtype != np.float32:
query = query.astype(np.float32) query = query.astype(np.float32)
if query.ndim == 1:
query = np.expand_dims(query, axis=0)
if self.distance_metric == "cosine": if self.distance_metric == "cosine":
faiss.normalize_L2(query) faiss.normalize_L2(query)
try:
params = faiss.SearchParametersHNSW() params = faiss.SearchParametersHNSW()
params.efSearch = ef params.zmq_port = kwargs.get("zmq_port", 5557)
params.zmq_port = kwargs.get("zmq_port", self.zmq_port) params.efSearch = kwargs.get("complexity", 32)
params.beam_size = kwargs.get("beam_width", 1)
batch_size = query.shape[0] batch_size = query.shape[0]
distances = np.empty((batch_size, top_k), dtype=np.float32) distances = np.empty((batch_size, top_k), dtype=np.float32)
@@ -375,12 +156,6 @@ class HNSWSearcher(LeannBackendSearcherInterface):
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params) self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
return {"labels": labels, "distances": distances} string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels]
except Exception as e: return {"labels": string_labels, "distances": distances}
print(f"💥 ERROR: HNSW search failed. Exception: {e}")
raise
def __del__(self):
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()
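After this refactor the HNSW search-time knobs travel through kwargs: "complexity" maps to efSearch, "zmq_port" selects the embedding server for recompute on pruned indices, and labels come back as string passage IDs. A hedged usage sketch; the import path, index path, and embedding dimension are assumptions.

import numpy as np
# from leann_backend_hnsw.hnsw_backend import HNSWSearcher   # module path assumed

# searcher = HNSWSearcher("indices/demo/demo.leann")
query = np.random.rand(1, 768).astype(np.float32)             # query embedding computed elsewhere
# out = searcher.search(query, top_k=10, complexity=64, zmq_port=5557)
# out["labels"] is a list of string passage IDs per query, e.g. [["doc_12", "doc_7", ...]]
# out["distances"] is a (batch, top_k) float32 array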

View File

@@ -56,23 +56,73 @@ class SimplePassageLoader:
def __len__(self) -> int:
return len(self.passages_data)
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""
Load passages from a JSON file
Load passages using metadata file with PassageManager for lazy loading
Expected format: {"passage_id": "passage_text", ...}
"""
if not os.path.exists(passages_file):
# Load metadata to get passage sources
print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
with open(meta_file, 'r') as f:
return SimplePassageLoader()
meta = json.load(f)
# Import PassageManager dynamically to avoid circular imports
import sys
import importlib.util
# Find the leann package directory relative to this file
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try:
with open(passages_file, 'r', encoding='utf-8') as f:
from leann.api import PassageManager
passages_data = json.load(f)
passage_manager = PassageManager(meta['passage_sources'])
print(f"Loaded {len(passages_data)} passages from {passages_file}")
finally:
return SimplePassageLoader(passages_data)
sys.path.pop(0)
# Load label map
passages_dir = Path(meta_file).parent
label_map_file = passages_dir / "leann.labels.map"
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
print(f"Initialized lazy passage loading for {len(label_map)} passages")
class LazyPassageLoader(SimplePassageLoader):
def __init__(self, passage_manager, label_map):
self.passage_manager = passage_manager
self.label_map = label_map
# Initialize parent with empty data
super().__init__({})
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID with lazy loading"""
try:
int_id = int(passage_id)
if int_id in self.label_map:
string_id = self.label_map[int_id]
passage_data = self.passage_manager.get_passage(string_id)
if passage_data and passage_data.get("text"):
return {"text": passage_data["text"]}
else:
print(f"DEBUG: Empty text for ID {int_id} -> {string_id}")
return {"text": ""}
else:
print(f"DEBUG: ID {int_id} not found in label_map")
return {"text": ""}
except Exception as e:
print(f"Error loading passages from {passages_file}: {e}")
print(f"DEBUG: Exception getting passage {passage_id}: {e}")
return SimplePassageLoader()
return {"text": ""}
def __len__(self) -> int:
return len(self.label_map)
return LazyPassageLoader(passage_manager, label_map)
def create_hnsw_embedding_server(
passages_file: Optional[str] = None,
@@ -158,7 +208,20 @@ def create_hnsw_embedding_server(
passages = SimplePassageLoader(passages_data)
print(f"Using provided passages data: {len(passages)} passages")
elif passages_file:
passages = load_passages_from_file(passages_file)
# Check if it's a metadata file or a single passages file
if passages_file.endswith('.meta.json'):
passages = load_passages_from_metadata(passages_file)
else:
# Try to find metadata file in same directory
passages_dir = Path(passages_file).parent
meta_files = list(passages_dir.glob("*.meta.json"))
if meta_files:
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
passages = load_passages_from_metadata(str(meta_files[0]))
else:
# Fallback to original single file loading (will cause warnings)
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
passages = SimplePassageLoader() # Use empty loader to avoid massive warnings
else:
passages = SimplePassageLoader()
print("No passages provided, using empty loader")
@@ -227,6 +290,11 @@ def create_hnsw_embedding_server(
_is_bge_model = "bge" in model_name.lower()
batch_size = len(texts_batch)
# Validate no empty texts
for i, text in enumerate(texts_batch):
if not text or text.strip() == "":
raise RuntimeError(f"FATAL: Empty text at batch index {i}, ID: {ids_batch[i] if i < len(ids_batch) else 'unknown'}")
# E5 model preprocessing
if _is_e5_model:
processed_texts_batch = [f"passage: {text}" for text in texts_batch]
@@ -373,14 +441,12 @@ def create_hnsw_embedding_server(
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
try:
print(f"DEBUG: Looking up passage ID {nid}")
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
raise RuntimeError(f"FATAL: Passage with ID {nid} returned empty text")
else:
txt = txtinfo["text"]
except (KeyError, IndexError):
print(f"DEBUG: Found text for ID {nid}, length: {len(txt)}")
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
texts.append(txt)
lookup_timer.print_elapsed()
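As an aside (not part of this changeset), a minimal sketch of the offset-based lookup the lazy loading path above relies on; file names follow the conventions in this diff, and demo.leann is a hypothetical index name:

import json
import pickle

with open("demo.leann.passages.idx", "rb") as f:
    offsets = pickle.load(f)        # passage_id -> byte offset into the JSONL file
with open("leann.labels.map", "rb") as f:
    label_map = pickle.load(f)      # int HNSW label -> string passage_id

def fetch_passage(int_label: int) -> dict:
    string_id = label_map[int_label]
    with open("demo.leann.passages.jsonl", "r", encoding="utf-8") as f:
        f.seek(offsets[string_id])  # jump straight to the record, no full-file parse
        return json.loads(f.readline())

print(fetch_passage(0)["text"][:80])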

View File

@@ -1,4 +1,4 @@
# File: packages/leann-backend-hnsw/pyproject.toml
# packages/leann-backend-hnsw/pyproject.toml
[build-system]
requires = ["scikit-build-core>=0.10", "numpy", "swig"]
@@ -10,7 +10,6 @@ version = "0.1.0"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = ["leann-core==0.1.0", "numpy"]
# Revert to the most standard scikit-build-core configuration
[tool.scikit-build]
wheel.packages = ["leann_backend_hnsw"]
editable.mode = "redirect"

View File

@@ -1,250 +1,202 @@
"""
This file contains the core API for the LEANN project, now definitively updated
with the correct, original embedding logic from the user's reference code.
"""
import json
import pickle
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
import uuid
import torch
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
from typing import List, Dict, Any, Optional
import numpy as np
import os
import json
from pathlib import Path
import openai
from dataclasses import dataclass, field
# --- Helper Functions for Embeddings ---
# --- The Correct, Verified Embedding Logic from old_code.py ---
def _get_openai_client():
def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
"""Initializes and returns an OpenAI client, ensuring the API key is set."""
"""Computes embeddings using sentence-transformers for consistent results."""
api_key = os.getenv("OPENAI_API_KEY")
try:
if not api_key:
raise ValueError("OPENAI_API_KEY environment variable not set, which is required for OpenAI models.")
return openai.OpenAI(api_key=api_key)
def _is_openai_model(model_name: str) -> bool:
"""Checks if the model is likely an OpenAI embedding model."""
# This is a simple check, can be improved with a more robust list.
return "ada" in model_name or "babbage" in model_name or model_name.startswith("text-embedding-")
def _compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings for a list of text chunks using either SentenceTransformers or OpenAI."""
if _is_openai_model(model_name):
print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...")
client = _get_openai_client()
response = client.embeddings.create(model=model_name, input=chunks)
embeddings = [item.embedding for item in response.data]
else:
from sentence_transformers import SentenceTransformer
except ImportError as e:
raise RuntimeError(
f"sentence-transformers not available. Install with: pip install sentence-transformers"
) from e
# Load model using sentence-transformers
model = SentenceTransformer(model_name)
model = model.half()
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
# use accelerator GPU or Mac GPU
import torch
if torch.cuda.is_available():
model = model.to("cuda")
elif torch.backends.mps.is_available():
model = model.to("mps")
embeddings = model.encode(chunks, show_progress_bar=True)
return np.asarray(embeddings, dtype=np.float32)
# Generate embeddings
embeddings = model.encode(chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=64)
def _get_embedding_dimensions(model_name: str) -> int:
return embeddings
"""Gets the embedding dimensions for a given model."""
print(f"INFO: Calculating dimensions for model '{model_name}'...")
if _is_openai_model(model_name):
client = _get_openai_client()
response = client.embeddings.create(model=model_name, input=["dummy text"])
return len(response.data[0].embedding)
else:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name)
dimension = model.get_sentence_embedding_dimension()
if dimension is None:
raise ValueError(f"Model '{model_name}' does not have a valid embedding dimension.")
return dimension
# --- Core API Classes (Restored and Unchanged) ---
@dataclass
class SearchResult:
"""Represents a single search result."""
id: str
id: int
score: float
text: str
metadata: Dict[str, Any] = field(default_factory=dict)
# --- Core Classes ---
class PassageManager:
def __init__(self, passage_sources: List[Dict[str, Any]]):
self.offset_maps = {}
self.passage_files = {}
self.global_offset_map = {} # Combined map for fast lookup
for source in passage_sources:
if source["type"] == "jsonl":
passage_file = source["path"]
index_file = source["index_path"]
if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, 'rb') as f:
offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
# Build global map for O(1) lookup
for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset)
def get_passage(self, passage_id: str) -> Dict[str, Any]:
if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id]
with open(passage_file, 'r', encoding='utf-8') as f:
f.seek(offset)
return json.loads(f.readline())
raise KeyError(f"Passage ID not found: {passage_id}")
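(Aside, not part of this changeset: the passage_sources consumed above are the entries that LeannBuilder.build_index writes into the .meta.json; a minimal sketch, with a hypothetical index path and passage ID:)

import json

with open("./indexes/demo.leann.meta.json", "r", encoding="utf-8") as f:
    meta = json.load(f)
pm = PassageManager(meta["passage_sources"])
print(pm.get_passage("example-passage-id")["text"])  # raises KeyError for unknown IDs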
class LeannBuilder:
"""
def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, **backend_kwargs):
The builder is responsible for building the index, it will compute the embeddings and then build the index.
It will also save the metadata of the index.
"""
def __init__(self, backend_name: str, embedding_model: str = "sentence-transformers/all-mpnet-base-v2", dimensions: Optional[int] = None, **backend_kwargs):
self.backend_name = backend_name
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
self.backend_factory = backend_factory
self.embedding_model = embedding_model
self.dimensions = dimensions
self.backend_kwargs = backend_kwargs
self.chunks: List[Dict[str, Any]] = []
print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
self.chunks.append({"text": text, "metadata": metadata or {}})
if metadata is None: metadata = {}
passage_id = metadata.get('id', str(uuid.uuid4()))
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
self.chunks.append(chunk_data)
def build_index(self, index_path: str):
if not self.chunks:
if not self.chunks: raise ValueError("No chunks added.")
raise ValueError("No chunks added. Use add_text() first.")
if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model)[0])
path = Path(index_path)
if self.dimensions is None:
index_dir = path.parent
self.dimensions = _get_embedding_dimensions(self.embedding_model)
index_name = path.name
print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")
index_dir.mkdir(parents=True, exist_ok=True)
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
offset_map = {}
with open(passages_file, 'w', encoding='utf-8') as f:
for chunk in self.chunks:
offset = f.tell()
json.dump({"id": chunk["id"], "text": chunk["text"], "metadata": chunk["metadata"]}, f, ensure_ascii=False)
f.write('\n')
offset_map[chunk["id"]] = offset
with open(offset_file, 'wb') as f: pickle.dump(offset_map, f)
texts_to_embed = [c["text"] for c in self.chunks]
embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)
embeddings = compute_embeddings(texts_to_embed, self.embedding_model)
string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = self.backend_kwargs.copy()
current_backend_kwargs = {**self.backend_kwargs, 'dimensions': self.dimensions}
current_backend_kwargs['dimensions'] = self.dimensions
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
build_kwargs = current_backend_kwargs.copy()
leann_meta_path = index_dir / f"{index_name}.meta.json"
build_kwargs['chunks'] = self.chunks
builder_instance.build(embeddings, index_path, **build_kwargs)
index_dir = Path(index_path).parent
leann_meta_path = index_dir / f"{Path(index_path).name}.meta.json"
meta_data = {
"version": "0.1.0",
"version": "1.0", "backend_name": self.backend_name, "embedding_model": self.embedding_model,
"backend_name": self.backend_name,
"dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs,
"embedding_model": self.embedding_model,
"passage_sources": [{"type": "jsonl", "path": str(passages_file), "index_path": str(offset_file)}]
"dimensions": self.dimensions,
"backend_kwargs": self.backend_kwargs,
"num_chunks": len(self.chunks),
"chunks": self.chunks,
}
with open(leann_meta_path, 'w', encoding='utf-8') as f:
json.dump(meta_data, f, indent=2)
print(f"INFO: Leann metadata saved to {leann_meta_path}")
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = is_compact and is_recompute # Pruned only if compact and recompute
with open(leann_meta_path, 'w', encoding='utf-8') as f: json.dump(meta_data, f, indent=2)
class LeannSearcher:
"""
The searcher is responsible for loading the index and performing the search.
It will also load the metadata of the index.
"""
def __init__(self, index_path: str, **backend_kwargs):
leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json"
meta_path_str = f"{index_path}.meta.json"
if not leann_meta_path.exists():
if not Path(meta_path_str).exists(): raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
raise FileNotFoundError(f"Leann metadata file not found at {leann_meta_path}. Was the index built with LeannBuilder?")
with open(meta_path_str, 'r', encoding='utf-8') as f: self.meta_data = json.load(f)
with open(leann_meta_path, 'r', encoding='utf-8') as f:
self.meta_data = json.load(f)
backend_name = self.meta_data['backend_name']
self.embedding_model = self.meta_data['embedding_model']
self.passage_manager = PassageManager(self.meta_data.get('passage_sources', []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
if backend_factory is None: raise ValueError(f"Backend '{backend_name}' not found.")
raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.")
final_kwargs = {**self.meta_data.get('backend_kwargs', {}), **backend_kwargs}
final_kwargs = self.meta_data.get("backend_kwargs", {})
final_kwargs.update(backend_kwargs)
if 'dimensions' not in final_kwargs:
final_kwargs['dimensions'] = self.meta_data.get('dimensions')
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")
def search(self, query: str, top_k: int = 5, **search_kwargs):
def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
query_embedding = _compute_embeddings([query], self.embedding_model)
print(f"🔍 DEBUG LeannSearcher.search() called:")
print(f" Query: '{query}'")
print(f" Top_k: {top_k}")
print(f" Search kwargs: {search_kwargs}")
query_embedding = compute_embeddings([query], self.embedding_model)
print(f" Generated embedding shape: {query_embedding.shape}")
print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
search_kwargs['embedding_model'] = self.embedding_model
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
print(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
enriched_results = []
for label, dist in zip(results['labels'][0], results['distances'][0]):
if 'labels' in results and 'distances' in results:
if label < len(self.meta_data['chunks']):
print(f" Processing {len(results['labels'][0])} passage IDs:")
chunk_info = self.meta_data['chunks'][label]
for i, (string_id, dist) in enumerate(zip(results['labels'][0], results['distances'][0])):
try:
passage_data = self.passage_manager.get_passage(string_id)
enriched_results.append(SearchResult(
id=label,
id=string_id, score=dist, text=passage_data['text'], metadata=passage_data.get('metadata', {})
score=dist,
text=chunk_info['text'],
metadata=chunk_info.get('metadata', {})
))
print(f" {i+1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text'][:60]}...")
except KeyError:
print(f" {i+1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!")
print(f" Final enriched results: {len(enriched_results)} passages")
return enriched_results
from .chat import get_llm
class LeannChat:
"""
def __init__(self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, **kwargs):
The chat is responsible for the conversation with the LLM.
It will use the searcher to get the results and then use the LLM to generate the response.
"""
def __init__(self, index_path: str, backend_name: Optional[str] = None, llm_model: str = "gpt-4o", **kwargs):
if backend_name is None:
leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json"
if not leann_meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {leann_meta_path}.")
with open(leann_meta_path, 'r', encoding='utf-8') as f:
meta_data = json.load(f)
backend_name = meta_data['backend_name']
self.searcher = LeannSearcher(index_path, **kwargs)
self.llm_model = llm_model
self.llm = get_llm(llm_config)
def ask(self, question: str, top_k=5, **kwargs):
"""
Additional keyword arguments (kwargs) for advanced search customization. Example usage:
chat.ask(
"What is ANN?",
top_k=10,
complexity=64,
beam_width=8,
USE_DEFERRED_FETCH=True,
skip_search_reorder=True,
recompute_beighbor_embeddings=True,
dedup_node_dis=True,
prune_ratio=0.1,
batch_recompute=True,
global_pruning=True
)
Supported kwargs:
- complexity (int): Search complexity parameter (default: 32)
- beam_width (int): Beam width for search (default: 4)
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
- skip_search_reorder (bool): Skip search reorder step (default: False)
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
- prune_ratio (float): Pruning ratio for search (default: 0.0)
- batch_recompute (bool): Enable batch recomputation (default: False)
- global_pruning (bool): Enable global pruning (default: False)
"""
results = self.searcher.search(question, top_k=top_k, **kwargs)
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
f"{context}\n\n"
f"Question: {question}\n\n"
"Please provide the best answer you can based on this context and your knowledge."
)
return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
print(f"DEBUG: Calling LLM with prompt: {prompt}...")
try:
client = _get_openai_client()
response = client.chat.completions.create(
model=self.llm_model,
messages=[
{"role": "system", "content": "You are a helpful assistant that answers questions based on the provided context."},
{"role": "user", "content": prompt}
]
)
return response.choices[0].message.content
except Exception as e:
print(f"ERROR: Failed to call OpenAI API: {e}")
return f"Error: Could not get a response from the LLM. {e}"
def start_interactive(self):
print("\nLeann Chat started (type 'quit' to exit)")

View File

@@ -0,0 +1,229 @@
#!/usr/bin/env python3
"""
This file contains the chat generation logic for the LEANN project,
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import logging
import os
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class LLMInterface(ABC):
"""Abstract base class for a generic Language Model (LLM) interface."""
@abstractmethod
def ask(self, prompt: str, **kwargs) -> str:
"""
Sends a prompt to the LLM and returns the generated text.
Args:
prompt: The input prompt for the LLM.
**kwargs: Additional keyword arguments for the LLM backend.
Returns:
The response string from the LLM.
"""
pass
class OllamaChat(LLMInterface):
"""LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
self.model = model
self.host = host
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
try:
import requests
# Check if the Ollama server is responsive
if host:
requests.get(host)
except ImportError:
raise ImportError("The 'requests' library is required for Ollama. Please install it with 'pip install requests'.")
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
raise ConnectionError(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
def ask(self, prompt: str, **kwargs) -> str:
import requests
import json
full_url = f"{self.host}/api/generate"
payload = {
"model": self.model,
"prompt": prompt,
"stream": False, # Keep it simple for now
"options": kwargs
}
logger.info(f"Sending request to Ollama: {payload}")
try:
response = requests.post(full_url, data=json.dumps(payload))
response.raise_for_status()
# The response from Ollama can be a stream of JSON objects, handle this
response_parts = response.text.strip().split('\n')
full_response = ""
for part in response_parts:
if part:
json_part = json.loads(part)
full_response += json_part.get("response", "")
if json_part.get("done"):
break
return full_response
except requests.exceptions.RequestException as e:
logger.error(f"Error communicating with Ollama: {e}")
return f"Error: Could not get a response from Ollama. Details: {e}"
class HFChat(LLMInterface):
"""LLM interface for local Hugging Face Transformers models."""
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'")
try:
from transformers import pipeline
import torch
except ImportError:
raise ImportError("The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'.")
# Auto-detect device
if torch.cuda.is_available():
device = "cuda"
logger.info("CUDA is available. Using GPU.")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
logger.info("MPS is available. Using Apple Silicon GPU.")
else:
device = "cpu"
logger.info("No GPU detected. Using CPU.")
self.pipeline = pipeline("text-generation", model=model_name, device=device)
def ask(self, prompt: str, **kwargs) -> str:
# Sensible defaults for text generation
params = {
"max_length": 500,
"num_return_sequences": 1,
**kwargs
}
logger.info(f"Generating text with Hugging Face model with params: {params}")
results = self.pipeline(prompt, **params)
# Handle different response formats from transformers
if isinstance(results, list) and len(results) > 0:
generated_text = results[0].get('generated_text', '') if isinstance(results[0], dict) else str(results[0])
else:
generated_text = str(results)
# Extract only the newly generated portion by removing the original prompt
if isinstance(generated_text, str) and generated_text.startswith(prompt):
response = generated_text[len(prompt):].strip()
else:
# Fallback: return the full response if prompt removal fails
response = str(generated_text)
return response
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
if not self.api_key:
raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.")
logger.info(f"Initializing OpenAI Chat with model='{model}'")
try:
import openai
self.client = openai.OpenAI(api_key=self.api_key)
except ImportError:
raise ImportError("The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'.")
def ask(self, prompt: str, **kwargs) -> str:
# Default parameters for OpenAI
params = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.get("max_tokens", 1000),
"temperature": kwargs.get("temperature", 0.7),
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
}
logger.info(f"Sending request to OpenAI with model {self.model}")
try:
response = self.client.chat.completions.create(**params)
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"Error communicating with OpenAI: {e}")
return f"Error: Could not get a response from OpenAI. Details: {e}"
class SimulatedChat(LLMInterface):
"""A simple simulated chat for testing and development."""
def ask(self, prompt: str, **kwargs) -> str:
logger.info("Simulating LLM call...")
print("Prompt sent to LLM (simulation):", prompt[:500] + "...")
return "This is a simulated answer from the LLM based on the retrieved context."
def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
"""
Factory function to get an LLM interface based on configuration.
Args:
llm_config: A dictionary specifying the LLM type and its parameters.
Example: {"type": "ollama", "model": "llama3"}
{"type": "hf", "model": "distilgpt2"}
None (for simulation mode)
Returns:
An instance of an LLMInterface subclass.
"""
if llm_config is None:
logger.info("No LLM config provided, defaulting to simulated chat.")
return SimulatedChat()
llm_type = llm_config.get("type", "simulated")
model = llm_config.get("model")
logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")
if llm_type == "ollama":
return OllamaChat(model=model or "llama3:8b", host=llm_config.get("host", "http://localhost:11434"))
elif llm_type == "hf":
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
elif llm_type == "simulated":
return SimulatedChat()
else:
raise ValueError(f"Unknown LLM type: '{llm_type}'")

View File

@@ -73,15 +73,17 @@ class EmbeddingServerManager:
self.server_process = subprocess.Popen(
command,
cwd=project_root,
# stdout=subprocess.PIPE,
stdout=subprocess.PIPE,
# stderr=subprocess.PIPE,
stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring
text=True,
encoding='utf-8'
encoding='utf-8',
bufsize=1, # Line buffered
universal_newlines=True
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 30, 0.5
max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print(f"✅ Embedding server is up and ready for this session.")
@@ -90,7 +92,7 @@ class EmbeddingServerManager:
return True
if self.server_process.poll() is not None:
print("❌ ERROR: Server process terminated unexpectedly during startup.")
self._log_monitor()
self._print_recent_output()
return False
time.sleep(wait_interval)
@@ -102,19 +104,32 @@ class EmbeddingServerManager:
print(f"❌ ERROR: Failed to start embedding server process: {e}") print(f"❌ ERROR: Failed to start embedding server process: {e}")
return False return False
def _print_recent_output(self):
"""Print any recent output from the server process."""
if not self.server_process or not self.server_process.stdout:
return
try:
# Read any available output
import select
import sys
if select.select([self.server_process.stdout], [], [], 0)[0]:
output = self.server_process.stdout.read()
if output:
print(f"[{self.backend_module_name} OUTPUT]: {output}")
except Exception as e:
print(f"Error reading server output: {e}")
def _log_monitor(self):
"""Monitors and prints the server's stdout and stderr."""
if not self.server_process:
return
try:
if self.server_process.stdout:
for line in iter(self.server_process.stdout.readline, ''):
while True:
print(f"[{self.backend_module_name} LOG]: {line.strip()}")
line = self.server_process.stdout.readline()
self.server_process.stdout.close()
if not line:
if self.server_process.stderr:
break
for line in iter(self.server_process.stderr.readline, ''):
print(f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True)
print(f"[{self.backend_module_name} ERROR]: {line.strip()}")
self.server_process.stderr.close()
except Exception as e:
print(f"Log monitor error: {e}")

View File

@@ -0,0 +1,97 @@
import json
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Any, List
import numpy as np
from .embedding_server_manager import EmbeddingServerManager
from .interface import LeannBackendSearcherInterface
class BaseSearcher(LeannBackendSearcherInterface, ABC):
"""
Abstract base class for Leann searchers, containing common logic for
loading metadata, managing embedding servers, and handling file paths.
"""
def __init__(self, index_path: str, backend_module_name: str, **kwargs):
"""
Initializes the BaseSearcher.
Args:
index_path: Path to the Leann index file (e.g., '.../my_index.leann').
backend_module_name: The specific embedding server module to use
(e.g., 'leann_backend_hnsw.hnsw_embedding_server').
**kwargs: Additional keyword arguments.
"""
self.index_path = Path(index_path)
self.index_dir = self.index_path.parent
self.meta = kwargs.get("meta", self._load_meta())
if not self.meta:
raise ValueError("Searcher requires metadata from .meta.json.")
self.dimensions = self.meta.get("dimensions")
if not self.dimensions:
raise ValueError("Dimensions not found in Leann metadata.")
self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
self.label_map = self._load_label_map()
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name
)
def _load_meta(self) -> Dict[str, Any]:
"""Loads the metadata file associated with the index."""
# This is the corrected logic for finding the meta file.
meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
with open(meta_path, 'r', encoding='utf-8') as f:
return json.load(f)
def _load_label_map(self) -> Dict[int, str]:
"""Loads the mapping from integer IDs to string IDs."""
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
return pickle.load(f)
def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> None:
"""
Ensures the embedding server is running if recompute is needed.
This is a helper for subclasses.
"""
if not self.embedding_model:
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
server_started = self.embedding_server_manager.start_server(
port=port,
model_name=self.embedding_model,
passages_file=passages_source_file,
distance_metric=kwargs.get("distance_metric"),
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {kwargs.get('zmq_port')}")
@abstractmethod
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
"""
Search for the top_k nearest neighbors of the query vector.
Must be implemented by subclasses.
"""
pass
def __del__(self):
"""Ensures the embedding server is stopped when the searcher is destroyed."""
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()
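As an aside (not part of this changeset), a sketch of what a concrete backend is expected to add on top of BaseSearcher; the class name, backend module name, and import path are illustrative assumptions:

from typing import Any, Dict
import numpy as np
from leann.searcher_base import BaseSearcher  # import path is an assumption

class DummySearcher(BaseSearcher):
    def __init__(self, index_path: str, **kwargs):
        super().__init__(index_path, backend_module_name="my_backend.embedding_server", **kwargs)

    def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
        # A real backend would start the embedding server before recompute, e.g.:
        # meta = self.index_dir / f"{self.index_path.name}.meta.json"
        # self._ensure_server_running(str(meta), port=kwargs.get("zmq_port", 5557), **kwargs)
        ids = list(self.label_map.values())[:top_k]  # pretend nearest neighbours
        distances = np.zeros((1, len(ids)), dtype=np.float32)
        return {"labels": [ids], "distances": distances}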

View File

@@ -1,24 +0,0 @@
from llama_index.core import VectorStoreIndex, Document
from llama_index.core.embeddings import resolve_embed_model
# Check the default embedding model
embed_model = resolve_embed_model("default")
print(f"Default embedding model: {embed_model}")
# Create a simple test document
doc = Document(text="This is a test document")
# Get embedding dimension
try:
# Test embedding
test_embedding = embed_model.get_text_embedding("test")
print(f"Embedding dimension: {len(test_embedding)}")
print(f"Embedding type: {type(test_embedding)}")
except Exception as e:
print(f"Error getting embedding: {e}")
# Alternative way to check dimension
if hasattr(embed_model, 'embed_dim'):
print(f"Model embed_dim attribute: {embed_model.embed_dim}")
elif hasattr(embed_model, 'dimension'):
print(f"Model dimension attribute: {embed_model.dimension}")

View File

@@ -1,20 +0,0 @@
from llama_index.core import VectorStoreIndex, Document
from llama_index.core.embeddings import resolve_embed_model
# Check the default embedding model
embed_model = resolve_embed_model("default")
print(f"Default embedding model: {embed_model}")
# Create a simple test
doc = Document(text="This is a test document")
index = VectorStoreIndex.from_documents([doc])
# Get the embedding model from the index
index_embed_model = index.embed_model
print(f"Index embedding model: {index_embed_model}")
# Check if it's OpenAI or local
if hasattr(index_embed_model, 'model_name'):
print(f"Model name: {index_embed_model.model_name}")
else:
print(f"Embedding model type: {type(index_embed_model)}")

uv.lock generated
View File

File diff suppressed because it is too large