merge main
@@ -1,3 +1,6 @@
import faulthandler
faulthandler.enable()

from llama_index.core import SimpleDirectoryReader, Settings
from llama_index.core.readers.base import BaseReader
from llama_index.node_parser.docling import DoclingNodeParser
@@ -7,7 +10,7 @@ import asyncio
import os
import dotenv
from leann.api import LeannBuilder, LeannSearcher, LeannChat
import leann_backend_diskann # Import to ensure backend registration
import leann_backend_hnsw # Import to ensure backend registration
import shutil
from pathlib import Path

@@ -21,7 +24,7 @@ file_extractor: dict[str, BaseReader] = {
".xlsx": reader,
}
node_parser = DoclingNodeParser(
chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=256)
chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=64)
)
print("Loading documents...")
documents = SimpleDirectoryReader(
@@ -32,10 +35,8 @@ documents = SimpleDirectoryReader(
required_exts=[".pdf", ".docx", ".pptx", ".xlsx"]
).load_data(show_progress=True)
print("Documents loaded.")
# Extract text from documents and prepare for Leann
all_texts = []
for doc in documents:
# DoclingNodeParser returns Node objects, which have a text attribute
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.text)
@@ -43,32 +44,35 @@ for doc in documents:
INDEX_DIR = Path("./test_pdf_index")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")

if INDEX_DIR.exists():
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
shutil.rmtree(INDEX_DIR)

print(f"\n[PHASE 1] Building Leann index...")

builder = LeannBuilder(
backend_name="diskann",
embedding_model="facebook/contriever", # Using a common sentence transformer model
graph_degree=32,
complexity=64
)

print(f"Loaded {len(all_texts)} text chunks from documents.")
for chunk_text in all_texts:
builder.add_text(chunk_text)
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")

builder.build_index(INDEX_PATH)
print(f"\nLeann index built at {INDEX_PATH}!")
print(f"\n[PHASE 1] Building Leann index...")

# CSR compact mode with recompute
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True
)

print(f"Loaded {len(all_texts)} text chunks from documents.")
for chunk_text in all_texts:
builder.add_text(chunk_text)

builder.build_index(INDEX_PATH)
print(f"\nLeann index built at {INDEX_PATH}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")

async def main():
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=INDEX_PATH)
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||
# query = "What is the Off-policy training in RL?"
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,embedding_model="facebook/contriever")
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
@@ -3,11 +3,17 @@ Simple demo showing basic leann usage
Run: uv run python examples/simple_demo.py
"""

import argparse
from leann import LeannBuilder, LeannSearcher, LeannChat


def main():
print("=== Leann Simple Demo ===")
parser = argparse.ArgumentParser(description="Simple demo of Leann with selectable embedding models.")
parser.add_argument("--embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.")
args = parser.parse_args()

print(f"=== Leann Simple Demo with {args.embedding_model} ===")
print()

# Sample knowledge base
@@ -24,10 +30,11 @@ def main():

print("1. Building index (no embeddings stored)...")
builder = LeannBuilder(
embedding_model="sentence-transformers/all-mpnet-base-v2",
prune_ratio=0.7, # Keep 30% of connections
embedding_model=args.embedding_model,
backend_name="hnsw",
)
builder.add_chunks(chunks)
for chunk in chunks:
builder.add_text(chunk)
builder.build_index("demo_knowledge.leann")
print()

@@ -49,14 +56,7 @@ def main():
print(f" Text: {result.text[:100]}...")
print()

print("3. Memory stats:")
stats = searcher.get_memory_stats()
print(f" Cache size: {stats.embedding_cache_size}")
print(f" Cache memory: {stats.embedding_cache_memory_mb:.1f} MB")
print(f" Total chunks: {stats.total_chunks}")
print()

print("4. Interactive chat demo:")
print("3. Interactive chat demo:")
print(" (Note: Requires OpenAI API key for real responses)")

chat = LeannChat("demo_knowledge.leann")

@@ -1,7 +0,0 @@
print("Initializing leann-backend-diskann...")

try:
from .diskann_backend import DiskannBackend
print("INFO: DiskANN backend loaded successfully")
except ImportError as e:
print(f"WARNING: Could not import DiskANN backend: {e}")
@@ -143,20 +143,16 @@ class DiskannBackend(LeannBackendFactoryInterface):
path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")

with open(meta_path, 'r') as f:
meta = json.load(f)

try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(meta.get("embedding_model"))
dimensions = model.get_sentence_embedding_dimension()
kwargs['dimensions'] = dimensions
except ImportError:
raise ImportError("sentence-transformers is required to infer embedding dimensions. Please install it.")
except Exception as e:
raise RuntimeError(f"Could not load SentenceTransformer model to get dimension: {e}")

dimensions = meta.get("dimensions")
if not dimensions:
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")

kwargs['dimensions'] = dimensions
return DiskannSearcher(index_path, **kwargs)

class DiskannBuilder(LeannBackendBuilderInterface):
packages/leann-backend-hnsw/leann_backend_hnsw/convert_to_csr.py (new file, 543 lines)
@@ -0,0 +1,543 @@
import struct
import sys
import numpy as np
import os
import argparse
import gc # Import garbage collector interface
import time
# --- FourCCs (add more if needed) ---
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b'IHNf', 'little')
# Add other HNSW fourccs if you expect different storage types inside HNSW
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example

EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')

# --- Helper functions for reading/writing binary data ---

def read_struct(f, fmt):
"""Reads data according to the struct format."""
size = struct.calcsize(fmt)
data = f.read(size)
if len(data) != size:
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}.")
return struct.unpack(fmt, data)[0]

def read_vector_raw(f, element_fmt_char):
"""Reads a vector (size followed by data), returns count and raw bytes."""
count = -1 # Initialize count
total_bytes = -1 # Initialize total_bytes
try:
count = read_struct(f, '<Q') # size_t usually 64-bit unsigned
element_size = struct.calcsize(element_fmt_char)
# --- FIX for MemoryError: Check for unreasonably large count ---
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
if count > max_reasonable_count or count < 0:
raise MemoryError(f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read.")

total_bytes = count * element_size
# --- FIX for MemoryError: Check for huge byte size before allocation ---
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
raise MemoryError(f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch.")

data_bytes = f.read(total_bytes)

if len(data_bytes) != total_bytes:
raise EOFError(f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}.")
return count, data_bytes
except (MemoryError, OverflowError) as e:
# Add context to the error message
print(f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}", file=sys.stderr)
raise e # Re-raise the original error type

def read_numpy_vector(f, np_dtype, struct_fmt_char):
"""Reads a vector into a NumPy array."""
count = -1 # Initialize count for robust error handling
print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end='', flush=True)
try:
count, data_bytes = read_vector_raw(f, struct_fmt_char)
print(f"Count={count}, Bytes={len(data_bytes)}")
if count > 0 and len(data_bytes) > 0:
arr = np.frombuffer(data_bytes, dtype=np_dtype)
if arr.size != count:
raise ValueError(f"Inconsistent array size after reading. Expected {count}, got {arr.size}")
return arr
elif count == 0:
return np.array([], dtype=np_dtype)
else:
raise ValueError("Read zero bytes but count > 0.")
except MemoryError as e:
# Now count should be defined (or -1 if error was in read_struct)
print(f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}", file=sys.stderr)
raise e
except Exception as e: # Catch other potential errors like ValueError
print(f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}", file=sys.stderr)
raise e


def write_numpy_vector(f, arr, struct_fmt_char):
"""Writes a NumPy array as a vector (size followed by data)."""
count = arr.size
f.write(struct.pack('<Q', count))
try:
expected_dtype = np.dtype(struct_fmt_char)
if arr.dtype != expected_dtype:
data_to_write = arr.astype(expected_dtype).tobytes()
else:
data_to_write = arr.tobytes()
f.write(data_to_write)
del data_to_write # Hint GC
except MemoryError as e:
print(f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}", file=sys.stderr)
raise e

def write_list_vector(f, lst, struct_fmt_char):
"""Writes a Python list as a vector iteratively."""
count = len(lst)
f.write(struct.pack('<Q', count))
fmt = '<' + struct_fmt_char
chunk_size = 1024 * 1024
element_size = struct.calcsize(fmt)
# Allocate buffer outside the loop if possible, or handle MemoryError during allocation
try:
buffer = bytearray(chunk_size * element_size)
except MemoryError:
print(f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).", file=sys.stderr)
raise
buffer_count = 0

for i, item in enumerate(lst):
try:
offset = buffer_count * element_size
struct.pack_into(fmt, buffer, offset, item)
buffer_count += 1

if buffer_count == chunk_size or i == count - 1:
f.write(buffer[:buffer_count * element_size])
buffer_count = 0

except struct.error as e:
print(f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}", file=sys.stderr)
raise e
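
# Illustrative sketch (not part of the original file): the helpers above serialize a vector as a
# little-endian uint64 count followed by the raw elements, so a small in-memory round trip with
# write_numpy_vector/read_numpy_vector would look roughly like:
#
#   import io
#   buf = io.BytesIO()
#   write_numpy_vector(buf, np.array([1, 2, 3], dtype=np.int32), 'i')
#   buf.seek(0)
#   assert read_numpy_vector(buf, np.int32, 'i').tolist() == [1, 2, 3]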


def get_cum_neighbors(cum_nneighbor_per_level_np, level):
"""Helper to get cumulative neighbors count, matching C++ logic."""
if level < 0: return 0
if level < len(cum_nneighbor_per_level_np):
return cum_nneighbor_per_level_np[level]
else:
return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0

def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
levels_np, compact_level_ptr, compact_node_offsets_np,
compact_neighbors_data, storage_fourcc, storage_data):
"""Write HNSW data in compact format following C++ read order exactly."""
# Write IndexHNSW Header
f_out.write(struct.pack('<I', original_hnsw_data['index_fourcc']))
f_out.write(struct.pack('<i', original_hnsw_data['d']))
f_out.write(struct.pack('<q', original_hnsw_data['ntotal']))
f_out.write(struct.pack('<q', original_hnsw_data['dummy1']))
f_out.write(struct.pack('<q', original_hnsw_data['dummy2']))
f_out.write(struct.pack('<?', original_hnsw_data['is_trained']))
f_out.write(struct.pack('<i', original_hnsw_data['metric_type']))
if original_hnsw_data['metric_type'] > 1:
f_out.write(struct.pack('<f', original_hnsw_data['metric_arg']))

# Write HNSW struct parts (standard order)
write_numpy_vector(f_out, assign_probas_np, 'd')
write_numpy_vector(f_out, cum_nneighbor_per_level_np, 'i')
write_numpy_vector(f_out, levels_np, 'i')

# Write compact format flag
f_out.write(struct.pack('<?', True)) # storage_is_compact = True

# Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST
if isinstance(compact_level_ptr, np.ndarray):
write_numpy_vector(f_out, compact_level_ptr, 'Q')
else:
write_list_vector(f_out, compact_level_ptr, 'Q')

write_numpy_vector(f_out, compact_node_offsets_np, 'Q')

# Write HNSW scalar parameters
f_out.write(struct.pack('<i', original_hnsw_data['entry_point']))
f_out.write(struct.pack('<i', original_hnsw_data['max_level']))
f_out.write(struct.pack('<i', original_hnsw_data['efConstruction']))
f_out.write(struct.pack('<i', original_hnsw_data['efSearch']))
f_out.write(struct.pack('<i', original_hnsw_data['dummy_upper_beam']))

# Write storage fourcc (this determines how to read what follows)
f_out.write(struct.pack('<I', storage_fourcc))

# Write compact neighbors data AFTER storage fourcc
write_list_vector(f_out, compact_neighbors_data, 'i')

# Write storage data if not NULL (only after neighbors)
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
f_out.write(storage_data)


# --- Main Conversion Logic ---

def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True):
"""
Converts an HNSW graph file to the CSR format.
Supports both original and already-compact formats (backward compatibility).

Args:
input_filename: Input HNSW index file
output_filename: Output CSR index file
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
"""
print(f"Starting conversion: {input_filename} -> {output_filename}")
start_time = time.time()
original_hnsw_data = {}
neighbors_np = None # Initialize to allow check in finally block
try:
with open(input_filename, 'rb') as f_in, open(output_filename, 'wb') as f_out:

# --- Read IndexHNSW FourCC and Header ---
print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...")
# ... (Keep the header reading logic as before) ...
hnsw_index_fourcc = read_struct(f_in, '<I')
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
print(f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.", file=sys.stderr)
return False
original_hnsw_data['index_fourcc'] = hnsw_index_fourcc
original_hnsw_data['d'] = read_struct(f_in, '<i')
original_hnsw_data['ntotal'] = read_struct(f_in, '<q')
original_hnsw_data['dummy1'] = read_struct(f_in, '<q')
original_hnsw_data['dummy2'] = read_struct(f_in, '<q')
original_hnsw_data['is_trained'] = read_struct(f_in, '?')
original_hnsw_data['metric_type'] = read_struct(f_in, '<i')
original_hnsw_data['metric_arg'] = 0.0
if original_hnsw_data['metric_type'] > 1:
original_hnsw_data['metric_arg'] = read_struct(f_in, '<f')
print(f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}")


# --- Read original HNSW struct data ---
print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...")
assign_probas_np = read_numpy_vector(f_in, np.float64, 'd')
print(f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})")
gc.collect()

cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, 'i')
print(f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})")
gc.collect()

levels_np = read_numpy_vector(f_in, np.int32, 'i')
print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})")
gc.collect()

ntotal = len(levels_np)
if ntotal != original_hnsw_data['ntotal']:
print(f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.", file=sys.stderr)
original_hnsw_data['ntotal'] = ntotal

# --- Check for compact format flag ---
print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...")
pos_before_compact = f_in.tell()
try:
is_compact_flag = read_struct(f_in, '<?')
print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}")

if is_compact_flag:
# Input is already in compact format - read compact data
print(f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data...")

compact_level_ptr = read_numpy_vector(f_in, np.uint64, 'Q')
print(f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})")

compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
print(f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})")

# Read scalar parameters
original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")

# Read storage fourcc
storage_fourcc = read_struct(f_in, '<I')
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}")

if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC:
# Read compact neighbors data
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
print(f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})")
compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np

# Skip storage data and write with NULL marker
print(f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker.")
storage_fourcc = NULL_INDEX_FOURCC
elif not prune_embeddings:
# Read and preserve compact neighbors and storage
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np

# Read remaining storage data
storage_data = f_in.read()
else:
# Already pruned (NULL storage)
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np
storage_data = b''

# Write the updated compact format
print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...")
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
levels_np, compact_level_ptr, compact_node_offsets_np,
compact_neighbors_data, storage_fourcc, storage_data if not prune_embeddings else b'')

print(f"[{time.time() - start_time:.2f}s] Conversion complete.")
return True

else:
# is_compact=False, rewind and read original format
f_in.seek(pos_before_compact)
print(f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format...")

except EOFError:
# No compact flag found, assume original format
f_in.seek(pos_before_compact)
print(f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format...")

# --- Handle potential extra byte in original format (like C++ code) ---
print(f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets...")
pos_before_probe = f_in.tell()
try:
suspected_flag = read_struct(f_in, '<B') # Read 1 byte
if suspected_flag == 0x00:
print(f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte.")
elif suspected_flag == 0x01:
print(f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False")
raise ValueError("Inconsistent compact flag state")
else:
# Rewind - this byte is part of offsets data
f_in.seek(pos_before_probe)
print(f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})")
except EOFError:
f_in.seek(pos_before_probe)
print(f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read")

# --- Read original format data ---
offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})")
if len(offsets_np) != ntotal + 1:
raise ValueError(f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}")
gc.collect()

print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...")
neighbors_np = read_numpy_vector(f_in, np.int32, 'i')
print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})")
expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0
if neighbors_np.size != expected_neighbors_size:
print(f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}.")
gc.collect()

original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")

print(f"[{time.time() - start_time:.2f}s] Checking for storage data...")
storage_fourcc = None
try:
storage_fourcc = read_struct(f_in, '<I')
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}.")
except EOFError:
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
except Exception as e:
print(f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}")


# --- Perform Conversion ---
print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...")

# Use lists for potentially huge data, np for offsets
compact_neighbors_data = []
compact_level_ptr = []
compact_node_offsets_np = np.zeros(ntotal + 1, dtype=np.uint64)

current_level_ptr_idx = 0
current_data_idx = 0
total_valid_neighbors_counted = 0 # For validation

# Optimize calculation by getting slices once per node if possible
for i in range(ntotal):
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
progress = (i / ntotal) * 100
elapsed = time.time() - start_time
print(f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...", end="")

node_max_level = levels_np[i] - 1
if node_max_level < -1: node_max_level = -1

node_ptr_start_index = current_level_ptr_idx
compact_node_offsets_np[i] = node_ptr_start_index

original_offset_start = offsets_np[i]
num_pointers_expected = (node_max_level + 1) + 1

for level in range(node_max_level + 1):
compact_level_ptr.append(current_data_idx)

begin_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level)
end_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level + 1)

begin_orig = int(begin_orig_np)
end_orig = int(end_orig_np)

neighbors_len = len(neighbors_np) # Cache length
begin_orig = min(max(0, begin_orig), neighbors_len)
end_orig = min(max(begin_orig, end_orig), neighbors_len)

if begin_orig < end_orig:
# Slicing creates a copy, could be memory intensive for large M
# Consider iterating if memory becomes an issue here
level_neighbors_slice = neighbors_np[begin_orig:end_orig]
valid_neighbors_mask = level_neighbors_slice >= 0
num_valid = np.count_nonzero(valid_neighbors_mask)

if num_valid > 0:
# Append valid neighbors
compact_neighbors_data.extend(level_neighbors_slice[valid_neighbors_mask])
current_data_idx += num_valid
total_valid_neighbors_counted += num_valid


compact_level_ptr.append(current_data_idx)
current_level_ptr_idx += num_pointers_expected

compact_node_offsets_np[ntotal] = current_level_ptr_idx
print(f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. ") # Clear progress line

# --- Validation Checks ---
print(f"[{time.time() - start_time:.2f}s] Running validation checks...")
valid_check_passed = True
# Check 1: Total valid neighbors count
print(f" Checking total valid neighbor count...")
expected_valid_count = np.sum(neighbors_np >= 0)
if total_valid_neighbors_counted != len(compact_neighbors_data):
print(f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
valid_check_passed = False
if expected_valid_count != len(compact_neighbors_data):
print(f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
valid_check_passed = False
else:
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")

# Check 2: Final pointer indices consistency
print(f" Checking final pointer indices...")
if compact_node_offsets_np[ntotal] != len(compact_level_ptr):
print(f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!", file=sys.stderr)
valid_check_passed = False
if (len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)) or \
(len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
print(f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
valid_check_passed = False
else:
print(f" OK: Final pointers match data size.")

if not valid_check_passed:
print("Error: Validation checks failed. Output file might be incorrect.", file=sys.stderr)
# Optional: Exit here if validation fails
# return False

# --- Explicitly delete large intermediate arrays ---
print(f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays...")
del neighbors_np
del offsets_np
gc.collect()

print(f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}")

# --- Write CSR HNSW graph data using unified function ---
print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...")

# Determine storage fourcc based on prune_embeddings
output_storage_fourcc = NULL_INDEX_FOURCC if prune_embeddings else (storage_fourcc if 'storage_fourcc' in locals() else NULL_INDEX_FOURCC)
if prune_embeddings:
print(f" Pruning embeddings: Writing NULL storage marker.")
storage_data = b''

# Use the unified write function
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
levels_np, compact_level_ptr, compact_node_offsets_np,
compact_neighbors_data, output_storage_fourcc, storage_data if not prune_embeddings else b'')

# Clean up memory
del assign_probas_np, cum_nneighbor_per_level_np, levels_np
del compact_neighbors_data, compact_level_ptr, compact_node_offsets_np
gc.collect()

end_time = time.time()
print(f"[{end_time - start_time:.2f}s] Conversion complete.")
return True

except FileNotFoundError:
print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
return False
except MemoryError as e:
print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr)
# Clean up potentially partially written output file?
try: os.remove(output_filename)
except OSError: pass
return False
except EOFError as e:
print(f"Error: Reached end of file unexpectedly reading {input_filename}. {e}", file=sys.stderr)
try: os.remove(output_filename)
except OSError: pass
return False
except Exception as e:
print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
import traceback
traceback.print_exc()
try:
os.remove(output_filename)
except OSError: pass
return False
# Ensure neighbors_np is deleted even if an error occurs after its allocation
finally:
if 'neighbors_np' in locals() and neighbors_np is not None:
del neighbors_np
gc.collect()


# --- Script Execution ---
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file.")
parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
parser.add_argument("output_csr_graph_file", help="Path to write the output CSR HNSW graph file")
parser.add_argument("--prune-embeddings", action="store_true", default=True,
help="Prune embedding storage (write NULL storage marker)")
parser.add_argument("--keep-embeddings", action="store_true",
help="Keep embedding storage (overrides --prune-embeddings)")

args = parser.parse_args()

if not os.path.exists(args.input_index_file):
print(f"Error: Input file not found: {args.input_index_file}", file=sys.stderr)
sys.exit(1)

if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file):
print(f"Error: Input and output filenames cannot be the same.", file=sys.stderr)
sys.exit(1)

prune_embeddings = args.prune_embeddings and not args.keep_embeddings
success = convert_hnsw_graph_to_csr(args.input_index_file, args.output_csr_graph_file, prune_embeddings)
if not success:
sys.exit(1)
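
A hedged usage sketch (not part of the diff itself): given the argparse definition above, the converter could be invoked roughly as follows; the index file names are illustrative, and running it as a module assumes the installed package layout implied by the import `from .convert_to_csr import convert_hnsw_graph_to_csr` later in this commit.

    python -m leann_backend_hnsw.convert_to_csr my_hnsw.index my_hnsw.csr.index                    # prune embeddings (default)
    python -m leann_backend_hnsw.convert_to_csr my_hnsw.index my_hnsw.csr.index --keep-embeddings  # keep embedding storage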
@@ -3,7 +3,7 @@ import os
import json
import struct
from pathlib import Path
from typing import Dict
from typing import Dict, Any
import contextlib
import threading
import time
@@ -12,9 +12,7 @@ import socket
import subprocess
import sys
# File: packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py

# ... (other imports unchanged) ...
from .convert_to_csr import convert_hnsw_graph_to_csr

from leann.registry import register_backend
from leann.interface import (
@@ -28,7 +26,7 @@ def get_metric_map():
return {
"mips": faiss.METRIC_INNER_PRODUCT,
"l2": faiss.METRIC_L2,
"cosine": faiss.METRIC_INNER_PRODUCT, # Will need normalization
"cosine": faiss.METRIC_INNER_PRODUCT,
}

def _check_port(port: int) -> bool:
@@ -46,7 +44,7 @@ class HNSWEmbeddingServerManager:
self.server_port = None
atexit.register(self.stop_server)

def start_server(self, port=5556, model_name="sentence-transformers/all-mpnet-base-v2", passages_file=None):
def start_server(self, port=5556, model_name="sentence-transformers/all-mpnet-base-v2", passages_file=None, distance_metric="mips"):
"""
Start the HNSW embedding server process.

@@ -54,6 +52,7 @@ class HNSWEmbeddingServerManager:
port: ZMQ port for the server
model_name: Name of the embedding model to use
passages_file: Optional path to passages JSON file
distance_metric: The distance metric to use
"""
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Reusing existing HNSW server process for this session (PID {self.server_process.pid})")
@@ -69,12 +68,12 @@ class HNSWEmbeddingServerManager:
try:
command = [
sys.executable,
"-m", "packages.leann-backend-hnsw.src.leann_backend_hnsw.hnsw_embedding_server",
"-m", "leann_backend_hnsw.hnsw_embedding_server",
"--zmq-port", str(port),
"--model-name", model_name
"--model-name", model_name,
"--distance-metric", distance_metric
]

# Add passages file if provided
if passages_file:
command.extend(["--passages-file", str(passages_file)])

@@ -153,26 +152,42 @@ class HNSWBackend(LeannBackendFactoryInterface):
path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")

with open(meta_path, 'r') as f:
meta = json.load(f)

try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(meta.get("embedding_model"))
dimensions = model.get_sentence_embedding_dimension()
kwargs['dimensions'] = dimensions
except ImportError:
raise ImportError("sentence-transformers is required to infer embedding dimensions. Please install it.")
except Exception as e:
raise RuntimeError(f"Could not load SentenceTransformer model to get dimension: {e}")

dimensions = meta.get("dimensions")
if not dimensions:
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")

kwargs['dimensions'] = dimensions
return HNSWSearcher(index_path, **kwargs)

class HNSWBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs
self.build_params = kwargs.copy()

# --- Configuration defaults with standardized names ---
self.is_compact = self.build_params.setdefault("is_compact", True)
self.is_recompute = self.build_params.setdefault("is_recompute", True)

# --- Additional Options ---
self.is_skip_neighbors = self.build_params.setdefault("is_skip_neighbors", False)
self.disk_cache_ratio = self.build_params.setdefault("disk_cache_ratio", 0.0)
self.external_storage_path = self.build_params.get("external_storage_path", None)

# --- Standard HNSW parameters ---
self.M = self.build_params.setdefault("M", 32)
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions")

if self.is_skip_neighbors and not self.is_compact:
raise ValueError("is_skip_neighbors can only be used with is_compact=True")

if self.is_recompute and not self.is_compact:
raise ValueError("is_recompute requires is_compact=True for efficiency")

def build(self, data: np.ndarray, index_path: str, **kwargs):
"""Build HNSW index using FAISS"""
@@ -189,97 +204,289 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)

build_kwargs = {**self.build_params, **kwargs}
metric_str = build_kwargs.get("distance_metric", "mips").lower()
metric_str = self.distance_metric.lower()
metric_enum = get_metric_map().get(metric_str)
print('metric_enum', metric_enum,' metric_str', metric_str)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")

# HNSW parameters
M = build_kwargs.get("M", 32) # Max connections per layer
efConstruction = build_kwargs.get("efConstruction", 200) # Size of the dynamic candidate list for construction
dim = data.shape[1]
M = self.M
efConstruction = self.efConstruction
dim = self.dimensions
if not dim:
dim = data.shape[1]

print(f"INFO: Building HNSW index for {data.shape[0]} vectors with metric {metric_enum}...")

try:
# Create HNSW index
if metric_enum == faiss.METRIC_INNER_PRODUCT:
index = faiss.IndexHNSWFlat(dim, M, metric_enum)
else: # L2
index = faiss.IndexHNSWFlat(dim, M, metric_enum)

# Set construction parameters
index = faiss.IndexHNSWFlat(dim, M, metric_enum)
index.hnsw.efConstruction = efConstruction

# Normalize vectors if using cosine similarity
if metric_str == "cosine":
faiss.normalize_L2(data)

# Add vectors to index
print('starting to add vectors to index')
index.add(data.shape[0], faiss.swig_ptr(data))
print('vectors added to index')

# Save index
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))

print(f"✅ HNSW index built successfully at '{index_file}'")

if self.is_compact:
self._convert_to_csr(index_file)

if self.is_recompute:
self._generate_passages_file(index_dir, index_prefix, **kwargs)

except Exception as e:
print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
raise

def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
try:
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
print(f"INFO: Converting HNSW index to {mode_str} format...")

csr_temp_file = index_file.with_suffix(".csr.tmp")

success = convert_hnsw_graph_to_csr(
str(index_file),
str(csr_temp_file),
prune_embeddings=self.is_recompute
)

if success:
print("✅ CSR conversion successful.")
import shutil
shutil.move(str(csr_temp_file), str(index_file))
print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
else:
# Clean up and fail fast
if csr_temp_file.exists():
os.remove(csr_temp_file)
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")

except Exception as e:
print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
raise

def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
"""Generate passages file for recompute mode"""
try:
chunks = kwargs.get('chunks', [])
if not chunks:
print("INFO: No chunks data provided, skipping passages file generation")
return

# Generate node_id to text mapping
passages_data = {}
for node_id, chunk in enumerate(chunks):
passages_data[str(node_id)] = chunk["text"]

# Save passages file
passages_file = index_dir / f"{index_prefix}.passages.json"
with open(passages_file, 'w', encoding='utf-8') as f:
json.dump(passages_data, f, ensure_ascii=False, indent=2)

print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")

except Exception as e:
print(f"💥 ERROR: Failed to generate passages file. Exception: {e}")
# Don't raise - this is not critical for index building
pass
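
# Illustrative sketch (assumed file contents, not taken from the diff): _generate_passages_file maps
# each node id to its chunk text and dumps the mapping as JSON, so a tiny index with two chunks
# would produce a passages file shaped roughly like:
#
#   {
#     "0": "first chunk text...",
#     "1": "second chunk text..."
#   }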

class HNSWSearcher(LeannBackendSearcherInterface):
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
"""
Robustly determines the index's storage status by parsing the file.

Returns:
A tuple (is_compact, is_pruned).
"""
if not index_file.exists():
return False, False

with open(index_file, 'rb') as f:
try:
def read_struct(fmt):
size = struct.calcsize(fmt)
data = f.read(size)
if len(data) != size:
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'.")
return struct.unpack(fmt, data)[0]

def skip_vector(element_size):
count = read_struct('<Q')
f.seek(count * element_size, 1)

# 1. Read up to the compact flag
read_struct('<I'); read_struct('<i'); read_struct('<q');
read_struct('<q'); read_struct('<q'); read_struct('<?')
metric_type = read_struct('<i')
if metric_type > 1: read_struct('<f')
skip_vector(8); skip_vector(4); skip_vector(4)

# 2. Check if there's a compact flag byte
# Try to read the compact flag, but handle both old and new formats
pos_before_compact = f.tell()
try:
is_compact = read_struct('<?')
print(f"INFO: Detected is_compact flag as: {is_compact}")
except (EOFError, struct.error):
# Old format without compact flag - assume non-compact
f.seek(pos_before_compact)
is_compact = False
print(f"INFO: No compact flag found, assuming is_compact=False")

# 3. Read storage FourCC to determine if pruned
is_pruned = False
try:
if is_compact:
# For compact, we need to skip pointers and scalars to get to the storage FourCC
skip_vector(8) # level_ptr
skip_vector(8) # node_offsets
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
storage_fourcc = read_struct('<I')
else:
# For non-compact, we need to read the flag probe, then skip offsets and neighbors
pos_before_probe = f.tell()
flag_byte = f.read(1)
if not (flag_byte and flag_byte == b'\x00'):
f.seek(pos_before_probe)
skip_vector(8); skip_vector(4) # offsets, neighbors
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
# Now we are at the storage. The entire rest is storage blob.
storage_fourcc = struct.unpack('<I', f.read(4))[0]

NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
if storage_fourcc == NULL_INDEX_FOURCC:
is_pruned = True
except (EOFError, struct.error):
# Cannot determine pruning status, assume not pruned
pass

print(f"INFO: Detected is_pruned as: {is_pruned}")
return is_compact, is_pruned

except (EOFError, struct.error) as e:
print(f"WARNING: Could not parse index file to detect format: {e}. Assuming standard, not pruned.")
return False, False

def __init__(self, index_path: str, **kwargs):
from . import faiss
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem

metric_str = kwargs.get("distance_metric", "mips").lower()
# Store configuration and paths for later use
self.config = kwargs.copy()
self.config["index_path"] = index_path
self.index_dir = index_dir
self.index_prefix = index_prefix

metric_str = self.config.get("distance_metric", "mips").lower()
metric_enum = get_metric_map().get(metric_str)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")

dimensions = kwargs.get("dimensions")
dimensions = self.config.get("dimensions")
if not dimensions:
raise ValueError("Vector dimension not provided to HNSWSearcher.")

try:
# Load FAISS HNSW index
index_file = index_dir / f"{index_prefix}.index"
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}")

self._index = faiss.read_index(str(index_file))
self.metric_str = metric_str
self.embedding_server_manager = HNSWEmbeddingServerManager()
print("✅ HNSW index loaded successfully.")

except Exception as e:
print(f"💥 ERROR: Failed to load HNSW index. Exception: {e}")
raise
index_file = index_dir / f"{index_prefix}.index"
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
self.is_compact, self.is_pruned = self._get_index_storage_status(index_file)

# Validate configuration constraints
if not self.is_compact and self.config.get("is_skip_neighbors", False):
raise ValueError("is_skip_neighbors can only be used with is_compact=True")

if self.config.get("is_recompute", False) and self.config.get("external_storage_path"):
raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously")

hnsw_config = faiss.HNSWIndexConfig()
hnsw_config.is_compact = self.is_compact

# Apply additional configuration options with strict validation
hnsw_config.is_skip_neighbors = self.config.get("is_skip_neighbors", False)
hnsw_config.is_recompute = self.is_pruned or self.config.get("is_recompute", False)
hnsw_config.disk_cache_ratio = self.config.get("disk_cache_ratio", 0.0)
hnsw_config.external_storage_path = self.config.get("external_storage_path")
hnsw_config.zmq_port = self.config.get("zmq_port", 5557)

if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")

print(f"INFO: Loading index with is_compact={self.is_compact}, is_pruned={self.is_pruned}")
print(f"INFO: Config - skip_neighbors={hnsw_config.is_skip_neighbors}, recompute={hnsw_config.is_recompute}")

self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)

if self.is_compact:
print("✅ Compact CSR format HNSW index loaded successfully.")
else:
print("✅ Standard HNSW index loaded successfully.")

self.metric_str = metric_str
self.embedding_server_manager = HNSWEmbeddingServerManager()

def _get_index_file(self, index_dir: Path, index_prefix: str) -> Path:
"""Get the appropriate index file path based on format"""
# We always use the same filename now, format is detected internally
return index_dir / f"{index_prefix}.index"

def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
"""Search using HNSW index with optional recompute functionality"""
from . import faiss
ef = kwargs.get("ef", 200) # Size of the dynamic candidate list for search
# Merge config with search-time kwargs
search_config = self.config.copy()
search_config.update(kwargs)

ef = search_config.get("ef", 200) # Size of the dynamic candidate list for search

# Recompute parameters
recompute_neighbor_embeddings = kwargs.get("recompute_neighbor_embeddings", False)
zmq_port = kwargs.get("zmq_port", 5556)
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
passages_file = kwargs.get("passages_file", None)
zmq_port = search_config.get("zmq_port", 5557)
embedding_model = search_config.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
passages_file = search_config.get("passages_file", None)

if recompute_neighbor_embeddings:
print(f"INFO: HNSW ZMQ mode enabled - ensuring embedding server is running")
# For recompute mode, try to find the passages file automatically
if self.is_pruned and not passages_file:
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
print(f"DEBUG: Checking for passages file at: {potential_passages_file}")
if potential_passages_file.exists():
passages_file = str(potential_passages_file)
print(f"INFO: Found passages file for recompute mode: {passages_file}")
else:
print(f"WARNING: No passages file found for recompute mode at {potential_passages_file}")

# If index is pruned (embeddings removed), we MUST start embedding server for recompute
if self.is_pruned:
print(f"INFO: Index is pruned - starting embedding server for recompute")

if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file):
print(f"WARNING: Failed to start HNSW embedding server, falling back to standard search")
kwargs['recompute_neighbor_embeddings'] = False
# CRITICAL: Check passages file exists - fail fast if not
if not passages_file:
raise RuntimeError(f"FATAL: Index is pruned but no passages file found. Cannot proceed with recompute mode.")

# Check if server is already running first
if _check_port(zmq_port):
print(f"INFO: Embedding server already running on port {zmq_port}")
else:
if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file, self.metric_str):
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")

# Give server extra time to fully initialize
print(f"INFO: Waiting for embedding server to fully initialize...")
time.sleep(3)

# Final verification
if not _check_port(zmq_port):
raise RuntimeError(f"Embedding server failed to start listening on port {zmq_port}")
else:
print(f"INFO: Index has embeddings stored - no recompute needed")

if query.dtype != np.float32:
query = query.astype(np.float32)
@@ -299,23 +506,14 @@ class HNSWSearcher(LeannBackendSearcherInterface):
distances = np.empty((batch_size, top_k), dtype=np.float32)
labels = np.empty((batch_size, top_k), dtype=np.int64)

if recompute_neighbor_embeddings:
# Use custom search with recompute
# This would require implementing custom HNSW search logic
# For now, we'll fall back to standard search
print("WARNING: Recompute functionality for HNSW not yet implemented, using standard search")
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
else:
# Standard FAISS search using SWIG API
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
# Use standard FAISS search - recompute is handled internally by FAISS
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))

return {"labels": labels, "distances": distances}

except Exception as e:
print(f"💥 ERROR: HNSW search failed. Exception: {e}")
batch_size = query.shape[0]
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
raise

def __del__(self):
if hasattr(self, 'embedding_server_manager'):
@@ -85,6 +85,7 @@ def create_hnsw_embedding_server(
max_batch_size: int = 128,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
custom_max_length_param: Optional[int] = None,
distance_metric: str = "mips",
):
"""
Create and start a ZMQ-based embedding server for HNSW backend.
@@ -100,8 +101,11 @@ def create_hnsw_embedding_server(
max_batch_size: Maximum batch size for processing
model_name: Transformer model name
custom_max_length_param: Custom max sequence length
distance_metric: The distance metric to use
"""
print(f"Loading tokenizer for {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
print(f"Tokenizer loaded successfully!")

# Device setup
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
@@ -122,7 +126,9 @@ def create_hnsw_embedding_server(

# Load model to the appropriate device
print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
print(f"Loading model {model_name}... (this may take a while if downloading)")
model = AutoModel.from_pretrained(model_name).to(device).eval()
print(f"Model {model_name} loaded successfully!")

# Check port availability
import socket
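Only the MPS probe is visible in this hunk; the device selection it feeds presumably looks something like the following sketch (the branch order is an assumption):

import torch

mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
if torch.cuda.is_available():
    device = torch.device("cuda")   # NVIDIA GPU
elif mps_available:
    device = torch.device("mps")    # Apple Silicon
else:
    device = torch.device("cpu")
print(f"Using device: {device}")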
@@ -218,6 +224,7 @@ def create_hnsw_embedding_server(
def process_batch(texts_batch, ids_batch, missing_ids):
"""Process a batch of texts and return embeddings"""
_is_e5_model = "e5" in model_name.lower()
_is_bge_model = "bge" in model_name.lower()
batch_size = len(texts_batch)

# E5 model preprocessing
@@ -258,7 +265,9 @@ def create_hnsw_embedding_server(
out = model(enc["input_ids"], enc["attention_mask"])

with pool_timer.timing():
if not hasattr(out, 'last_hidden_state'):
if _is_bge_model:
pooled_embeddings = out.last_hidden_state[:, 0]
elif not hasattr(out, 'last_hidden_state'):
if isinstance(out, torch.Tensor) and len(out.shape) == 2:
pooled_embeddings = out
else:
@@ -275,7 +284,7 @@ def create_hnsw_embedding_server(
pooled_embeddings = sum_embeddings / sum_mask

final_embeddings = pooled_embeddings
if _is_e5_model:
if _is_e5_model or _is_bge_model:
with norm_timer.timing():
final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)
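The pooling branch above amounts to CLS pooling for BGE-style models, masked mean pooling otherwise, and L2 normalization for E5/BGE outputs. A self-contained sketch of those three steps (tensor and function names here are illustrative, not the file's actual code):

import torch
import torch.nn.functional as F

def pool_and_normalize(last_hidden_state, attention_mask, use_cls=False, normalize=False):
    if use_cls:
        pooled = last_hidden_state[:, 0]                      # CLS pooling (BGE-style)
    else:
        mask = attention_mask.unsqueeze(-1).float()
        summed = (last_hidden_state * mask).sum(dim=1)        # masked sum over tokens
        counts = mask.sum(dim=1).clamp(min=1e-9)              # avoid division by zero
        pooled = summed / counts                              # mean pooling
    if normalize:
        pooled = F.normalize(pooled, p=2, dim=1)              # unit-length vectors (E5/BGE)
    return pooled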

@@ -364,13 +373,14 @@ def create_hnsw_embedding_server(
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
print(f"Warning: Passage with ID {nid} not found")
missing_ids.append(nid)
txt = ""
else:
txt = txtinfo["text"]
try:
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
else:
txt = txtinfo["text"]
except (KeyError, IndexError):
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
texts.append(txt)
lookup_timer.print_elapsed()
@@ -403,14 +413,14 @@ def create_hnsw_embedding_server(
calc_timer = DeviceTimer("distance calculation", device)
with calc_timer.timing():
with torch.no_grad():
if is_similarity_metric():
node_embeddings_np = node_embeddings_tensor.cpu().numpy()
query_np = query_tensor.cpu().numpy()
distances = -np.dot(node_embeddings_np, query_np)
else:
if distance_metric == "l2":
node_embeddings_np = node_embeddings_tensor.cpu().numpy().astype(np.float32)
query_np = query_tensor.cpu().numpy().astype(np.float32)
distances = np.sum(np.square(node_embeddings_np - query_np.reshape(1, -1)), axis=1)
else: # mips or cosine
node_embeddings_np = node_embeddings_tensor.cpu().numpy()
query_np = query_tensor.cpu().numpy()
distances = -np.dot(node_embeddings_np, query_np)
calc_timer.print_elapsed()
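In plain numpy the two branches reduce to squared L2 versus negated inner product, so that a smaller value is always better. A tiny numerical illustration with made-up vectors:

import numpy as np

node_embeddings_np = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
query_np = np.array([1.0, 1.0], dtype=np.float32)

l2 = np.sum(np.square(node_embeddings_np - query_np.reshape(1, -1)), axis=1)  # [1.0, 1.0]
mips = -np.dot(node_embeddings_np, query_np)                                  # [-1.0, -1.0]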

try:
@@ -450,13 +460,14 @@ def create_hnsw_embedding_server(
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
print(f"Warning: Passage with ID {nid} not found")
missing_ids.append(nid)
txt = ""
else:
txt = txtinfo["text"]
try:
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
else:
txt = txtinfo["text"]
except (KeyError, IndexError):
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
texts.append(txt)
lookup_timer.print_elapsed()
@@ -566,6 +577,7 @@ if __name__ == "__main__":
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name")
parser.add_argument("--custom-max-length", type=int, default=None, help="Override model's default max sequence length")
parser.add_argument("--distance-metric", type=str, default="mips", help="Distance metric to use")

args = parser.parse_args()

@@ -580,4 +592,5 @@ if __name__ == "__main__":
max_batch_size=args.max_batch_size,
model_name=args.model_name,
custom_max_length_param=args.custom_max_length,
distance_metric=args.distance_metric,
)
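With the flag wired through argparse, the server can be started with an explicit metric. The script name and the remaining required flags (port, passages source) are not visible in this hunk, so the invocation below is only an assumed example:

#   python hnsw_embedding_server.py \
#       --model-name facebook/contriever \
#       --max-batch-size 128 \
#       --distance-metric l2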
packages/leann-core/src/leann/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# This file makes the 'leann' directory a Python package.

from .api import LeannBuilder, LeannSearcher, LeannChat, SearchResult

# Import backends to ensure they are registered
try:
import leann_backend_hnsw
except ImportError:
pass

try:
import leann_backend_diskann
except ImportError:
pass


__all__ = ['LeannBuilder', 'LeannSearcher', 'LeannChat', 'SearchResult']
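With this package initializer in place, callers can import everything from the top level; a minimal usage sketch (the printed backend names depend on which optional backend packages are actually installed):

# Backends self-register on import; a missing optional backend is not fatal.
from leann import LeannBuilder, LeannSearcher, LeannChat, SearchResult
from leann.registry import BACKEND_REGISTRY

print(sorted(BACKEND_REGISTRY.keys()))  # e.g. ['diskann', 'hnsw'] when both are installed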
@@ -5,56 +5,104 @@ import numpy as np
import os
import json
from pathlib import Path
import openai # Import openai library
import openai
from dataclasses import dataclass, field

# --- Helper Functions for Embeddings ---

def _get_openai_client():
"""Initializes and returns an OpenAI client, ensuring the API key is set."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEY environment variable not set, which is required for OpenAI models.")
return openai.OpenAI(api_key=api_key)

def _is_openai_model(model_name: str) -> bool:
"""Checks if the model is likely an OpenAI embedding model."""
# This is a simple check; it could be improved with a more robust list.
return "ada" in model_name or "babbage" in model_name or model_name.startswith("text-embedding-")

# A helper function for computing embeddings on the fly
def _compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
try:
"""Computes embeddings for a list of text chunks using either SentenceTransformers or OpenAI."""
if _is_openai_model(model_name):
print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...")
client = _get_openai_client()
response = client.embeddings.create(model=model_name, input=chunks)
embeddings = [item.embedding for item in response.data]
else:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name)
print(f"INFO: Computing embeddings for {len(chunks)} chunks using '{model_name}'...")
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
embeddings = model.encode(chunks, show_progress_bar=True)
return np.asarray(embeddings, dtype=np.float32)
except ImportError:
print("WARNING: sentence-transformers not installed. Falling back to random embeddings.")
# If it is not installed, generate random vectors for testing
# TODO: the dimension should come from a single authoritative place
return np.random.rand(len(chunks), 768).astype(np.float32)

return np.asarray(embeddings, dtype=np.float32)

def _get_embedding_dimensions(model_name: str) -> int:
"""Gets the embedding dimensions for a given model."""
print(f"INFO: Calculating dimensions for model '{model_name}'...")
if _is_openai_model(model_name):
client = _get_openai_client()
response = client.embeddings.create(model=model_name, input=["dummy text"])
return len(response.data[0].embedding)
else:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name)
dimension = model.get_sentence_embedding_dimension()
if dimension is None:
raise ValueError(f"Model '{model_name}' does not have a valid embedding dimension.")
return dimension
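A short usage sketch of these two helpers; the model name is illustrative, and the OpenAI path additionally requires OPENAI_API_KEY to be set:

dims = _get_embedding_dimensions("facebook/contriever")          # SentenceTransformer path
vecs = _compute_embeddings(["hello world", "a second chunk"], "facebook/contriever")
assert vecs.dtype == np.float32 and vecs.shape == (2, dims)
# _get_embedding_dimensions("text-embedding-3-small") would take the OpenAI branch instead.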

@dataclass
class SearchResult:
"""Represents a single search result."""
id: int
score: float
text: str
metadata: Dict[str, Any] = field(default_factory=dict)
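Because results are now dataclass instances instead of plain dicts, callers use attribute access; for example (values illustrative):

r = SearchResult(id=3, score=0.87, text="example passage", metadata={"source": "doc.pdf"})
print(r.id, r.score, r.text, r.metadata.get("source"))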

# --- Core Classes ---

class LeannBuilder:
"""
High-level API responsible for building a Leann index.
It coordinates embedding computation and backend index construction.
The builder is responsible for building the index; it computes the embeddings and then builds the index.
It will also save the metadata of the index.
"""
def __init__(self, backend_name: str, embedding_model: str = "sentence-transformers/all-mpnet-base-v2", **backend_kwargs):
def __init__(self, backend_name: str, embedding_model: str = "sentence-transformers/all-mpnet-base-v2", dimensions: Optional[int] = None, **backend_kwargs):
self.backend_name = backend_name
self.backend_factory = BACKEND_REGISTRY.get(backend_name)
if self.backend_factory is None:
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found or not registered.")

self.backend_factory = backend_factory

self.embedding_model = embedding_model
self.dimensions = dimensions
self.backend_kwargs = backend_kwargs
self.chunks: List[Dict[str, Any]] = []
print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")

def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
# Simple chunking logic
self.chunks.append({"text": text, "metadata": metadata or {}})

def build_index(self, index_path: str):
if not self.chunks:
raise ValueError("No chunks added. Use add_text() first.")

# 1. Compute embeddings (this is leann-core's responsibility)
if self.dimensions is None:
self.dimensions = _get_embedding_dimensions(self.embedding_model)
print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")

texts_to_embed = [c["text"] for c in self.chunks]
embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)

# 2. Create the builder instance and build the index
builder_instance = self.backend_factory.builder(**self.backend_kwargs)
builder_instance.build(embeddings, index_path, **self.backend_kwargs)
current_backend_kwargs = self.backend_kwargs.copy()
current_backend_kwargs['dimensions'] = self.dimensions
builder_instance = self.backend_factory.builder(**current_backend_kwargs)

build_kwargs = current_backend_kwargs.copy()
build_kwargs['chunks'] = self.chunks
builder_instance.build(embeddings, index_path, **build_kwargs)

# 3. Save Leann-specific metadata (without the vectors)
index_dir = Path(index_path).parent
leann_meta_path = index_dir / f"{Path(index_path).name}.meta.json"

@@ -62,6 +110,8 @@ class LeannBuilder:
"version": "0.1.0",
"backend_name": self.backend_name,
"embedding_model": self.embedding_model,
"dimensions": self.dimensions,
"backend_kwargs": self.backend_kwargs,
"num_chunks": len(self.chunks),
"chunks": self.chunks,
}
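The sidecar file written here is plain JSON, so the stored settings can be inspected directly; a small sketch (the file name is illustrative, the keys are taken from the dict above):

import json
from pathlib import Path

meta = json.loads(Path("my_index.leann.meta.json").read_text())
print(meta["backend_name"], meta["embedding_model"], meta["dimensions"], meta["num_chunks"])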
@@ -72,7 +122,8 @@ class LeannBuilder:

class LeannSearcher:
"""
High-level API responsible for loading an index and running retrieval.
The searcher is responsible for loading the index and performing the search.
It will also load the metadata of the index.
"""
def __init__(self, index_path: str, **backend_kwargs):
leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json"
@@ -89,36 +140,39 @@ class LeannSearcher:
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.")

# Create the searcher instance
self.backend_impl = backend_factory.searcher(index_path, **backend_kwargs)
final_kwargs = self.meta_data.get("backend_kwargs", {})
final_kwargs.update(backend_kwargs)
if 'dimensions' not in final_kwargs:
final_kwargs['dimensions'] = self.meta_data.get('dimensions')

self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")

def search(self, query: str, top_k: int = 5, **search_kwargs):
query_embedding = _compute_embeddings([query], self.embedding_model)

# Delegate to the backend's search method
search_kwargs['embedding_model'] = self.embedding_model
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)

# Enrich the returned results with the original text and metadata
enriched_results = []
for label, dist in zip(results['labels'][0], results['distances'][0]):
if label < len(self.meta_data['chunks']):
chunk_info = self.meta_data['chunks'][label]
enriched_results.append({
"id": label,
"score": dist,
"text": chunk_info['text'],
"metadata": chunk_info['metadata']
})
enriched_results.append(SearchResult(
id=label,
score=dist,
text=chunk_info['text'],
metadata=chunk_info.get('metadata', {})
))
return enriched_results


class LeannChat:
"""
A conversational RAG interface that wraps the Searcher and an LLM.
The chat is responsible for the conversation with the LLM.
It will use the searcher to get the results and then use the LLM to generate the response.
"""
def __init__(self, index_path: str, backend_name: Optional[str] = None, llm_model: str = "gpt-4o", **kwargs):
# If the user did not specify a backend, try to read it from the index metadata
if backend_name is None:
leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json"
if not leann_meta_path.exists():
@@ -129,15 +183,6 @@ class LeannChat:

self.searcher = LeannSearcher(index_path, **kwargs)
self.llm_model = llm_model
self.openai_client = None # Lazy load

def _get_openai_client(self):
if self.openai_client is None:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEY environment variable not set.")
self.openai_client = openai.OpenAI(api_key=api_key)
return self.openai_client

def ask(self, question: str, top_k=5, **kwargs):
"""
@@ -169,15 +214,13 @@ class LeannChat:
"""

results = self.searcher.search(question, top_k=top_k, **kwargs)
context = "\n\n".join([r['text'] for r in results])
context = "\n\n".join([r.text for r in results])

# 2. Build the prompt
prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"

# 3. Call the LLM
print(f"DEBUG: Calling LLM with prompt: {prompt}...")
try:
client = self._get_openai_client()
client = _get_openai_client()
response = client.chat.completions.create(
model=self.llm_model,
messages=[
@@ -1,10 +1,13 @@
# packages/leann-core/src/leann/registry.py

# The global backend registry dictionary
BACKEND_REGISTRY = {}
from typing import Dict, TYPE_CHECKING
if TYPE_CHECKING:
from leann.interface import LeannBackendFactoryInterface

BACKEND_REGISTRY: Dict[str, 'LeannBackendFactoryInterface'] = {}

def register_backend(name: str):
"""A decorator used to register a new backend class."""
"""A decorator to register a new backend class."""
def decorator(cls):
print(f"INFO: Registering backend '{name}'")
BACKEND_REGISTRY[name] = cls
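A minimal sketch of how a backend package would use this decorator; the factory class below is purely illustrative and only mirrors the builder/searcher entry points used elsewhere in this commit:

@register_backend("my_backend")
class MyBackendFactory:
    @staticmethod
    def builder(**kwargs):
        ...  # return a backend-specific builder

    @staticmethod
    def searcher(index_path, **kwargs):
        ...  # return a backend-specific searcher

assert "my_backend" in BACKEND_REGISTRY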

@@ -30,7 +30,8 @@ dependencies = [
"llama-index>=0.12.44",
"llama-index-readers-docling",
"llama-index-node-parser-docling",
"ipykernel==6.29.5"
"ipykernel==6.29.5",
"msgpack>=1.1.1",
]

[project.optional-dependencies]
uv.lock | 82 (generated)
@@ -2070,6 +2070,7 @@ dependencies = [
{ name = "llama-index" },
{ name = "llama-index-node-parser-docling" },
{ name = "llama-index-readers-docling" },
{ name = "msgpack" },
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "ollama" },
@@ -2109,6 +2110,7 @@ requires-dist = [
{ name = "llama-index-node-parser-docling" },
{ name = "llama-index-readers-docling" },
{ name = "matplotlib", marker = "extra == 'dev'" },
{ name = "msgpack", specifier = ">=1.1.1" },
{ name = "numpy", specifier = ">=1.26.0" },
{ name = "ollama" },
{ name = "openai", specifier = ">=1.0.0" },
@@ -2730,13 +2732,13 @@ name = "mlx-lm"
version = "0.25.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2" },
{ name = "mlx" },
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "protobuf" },
{ name = "pyyaml" },
{ name = "transformers", extra = ["sentencepiece"] },
{ name = "jinja2", marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and sys_platform == 'darwin'" },
{ name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform == 'darwin'" },
{ name = "protobuf", marker = "sys_platform == 'darwin'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin'" },
{ name = "transformers", extra = ["sentencepiece"], marker = "sys_platform == 'darwin'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/52/34/d2b551d4519a9bafdfd29f76987dbcaaee370b974cfa81acfba782d6063f/mlx_lm-0.25.2.tar.gz", hash = "sha256:7d01baa66916aabd5be7345786acbfaf01d4e3f646759d7232e14aebfd4420a8", size = 146815 }
wheels = [
@@ -2786,6 +2788,54 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198 },
]
[[package]]
name = "msgpack"
version = "1.1.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/45/b1/ea4f68038a18c77c9467400d166d74c4ffa536f34761f7983a104357e614/msgpack-1.1.1.tar.gz", hash = "sha256:77b79ce34a2bdab2594f490c8e80dd62a02d650b91a75159a63ec413b8d104cd", size = 173555 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/33/52/f30da112c1dc92cf64f57d08a273ac771e7b29dea10b4b30369b2d7e8546/msgpack-1.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:353b6fc0c36fde68b661a12949d7d49f8f51ff5fa019c1e47c87c4ff34b080ed", size = 81799 },
{ url = "https://files.pythonhosted.org/packages/e4/35/7bfc0def2f04ab4145f7f108e3563f9b4abae4ab0ed78a61f350518cc4d2/msgpack-1.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:79c408fcf76a958491b4e3b103d1c417044544b68e96d06432a189b43d1215c8", size = 78278 },
{ url = "https://files.pythonhosted.org/packages/e8/c5/df5d6c1c39856bc55f800bf82778fd4c11370667f9b9e9d51b2f5da88f20/msgpack-1.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78426096939c2c7482bf31ef15ca219a9e24460289c00dd0b94411040bb73ad2", size = 402805 },
{ url = "https://files.pythonhosted.org/packages/20/8e/0bb8c977efecfe6ea7116e2ed73a78a8d32a947f94d272586cf02a9757db/msgpack-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b17ba27727a36cb73aabacaa44b13090feb88a01d012c0f4be70c00f75048b4", size = 408642 },
{ url = "https://files.pythonhosted.org/packages/59/a1/731d52c1aeec52006be6d1f8027c49fdc2cfc3ab7cbe7c28335b2910d7b6/msgpack-1.1.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a17ac1ea6ec3c7687d70201cfda3b1e8061466f28f686c24f627cae4ea8efd0", size = 395143 },
{ url = "https://files.pythonhosted.org/packages/2b/92/b42911c52cda2ba67a6418ffa7d08969edf2e760b09015593c8a8a27a97d/msgpack-1.1.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:88d1e966c9235c1d4e2afac21ca83933ba59537e2e2727a999bf3f515ca2af26", size = 395986 },
{ url = "https://files.pythonhosted.org/packages/61/dc/8ae165337e70118d4dab651b8b562dd5066dd1e6dd57b038f32ebc3e2f07/msgpack-1.1.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:f6d58656842e1b2ddbe07f43f56b10a60f2ba5826164910968f5933e5178af75", size = 402682 },
{ url = "https://files.pythonhosted.org/packages/58/27/555851cb98dcbd6ce041df1eacb25ac30646575e9cd125681aa2f4b1b6f1/msgpack-1.1.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:96decdfc4adcbc087f5ea7ebdcfd3dee9a13358cae6e81d54be962efc38f6338", size = 406368 },
{ url = "https://files.pythonhosted.org/packages/d4/64/39a26add4ce16f24e99eabb9005e44c663db00e3fce17d4ae1ae9d61df99/msgpack-1.1.1-cp310-cp310-win32.whl", hash = "sha256:6640fd979ca9a212e4bcdf6eb74051ade2c690b862b679bfcb60ae46e6dc4bfd", size = 65004 },
{ url = "https://files.pythonhosted.org/packages/7d/18/73dfa3e9d5d7450d39debde5b0d848139f7de23bd637a4506e36c9800fd6/msgpack-1.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:8b65b53204fe1bd037c40c4148d00ef918eb2108d24c9aaa20bc31f9810ce0a8", size = 71548 },
{ url = "https://files.pythonhosted.org/packages/7f/83/97f24bf9848af23fe2ba04380388216defc49a8af6da0c28cc636d722502/msgpack-1.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:71ef05c1726884e44f8b1d1773604ab5d4d17729d8491403a705e649116c9558", size = 82728 },
{ url = "https://files.pythonhosted.org/packages/aa/7f/2eaa388267a78401f6e182662b08a588ef4f3de6f0eab1ec09736a7aaa2b/msgpack-1.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:36043272c6aede309d29d56851f8841ba907a1a3d04435e43e8a19928e243c1d", size = 79279 },
{ url = "https://files.pythonhosted.org/packages/f8/46/31eb60f4452c96161e4dfd26dbca562b4ec68c72e4ad07d9566d7ea35e8a/msgpack-1.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a32747b1b39c3ac27d0670122b57e6e57f28eefb725e0b625618d1b59bf9d1e0", size = 423859 },
{ url = "https://files.pythonhosted.org/packages/45/16/a20fa8c32825cc7ae8457fab45670c7a8996d7746ce80ce41cc51e3b2bd7/msgpack-1.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a8b10fdb84a43e50d38057b06901ec9da52baac6983d3f709d8507f3889d43f", size = 429975 },
{ url = "https://files.pythonhosted.org/packages/86/ea/6c958e07692367feeb1a1594d35e22b62f7f476f3c568b002a5ea09d443d/msgpack-1.1.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba0c325c3f485dc54ec298d8b024e134acf07c10d494ffa24373bea729acf704", size = 413528 },
{ url = "https://files.pythonhosted.org/packages/75/05/ac84063c5dae79722bda9f68b878dc31fc3059adb8633c79f1e82c2cd946/msgpack-1.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:88daaf7d146e48ec71212ce21109b66e06a98e5e44dca47d853cbfe171d6c8d2", size = 413338 },
{ url = "https://files.pythonhosted.org/packages/69/e8/fe86b082c781d3e1c09ca0f4dacd457ede60a13119b6ce939efe2ea77b76/msgpack-1.1.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:d8b55ea20dc59b181d3f47103f113e6f28a5e1c89fd5b67b9140edb442ab67f2", size = 422658 },
{ url = "https://files.pythonhosted.org/packages/3b/2b/bafc9924df52d8f3bb7c00d24e57be477f4d0f967c0a31ef5e2225e035c7/msgpack-1.1.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4a28e8072ae9779f20427af07f53bbb8b4aa81151054e882aee333b158da8752", size = 427124 },
{ url = "https://files.pythonhosted.org/packages/a2/3b/1f717e17e53e0ed0b68fa59e9188f3f610c79d7151f0e52ff3cd8eb6b2dc/msgpack-1.1.1-cp311-cp311-win32.whl", hash = "sha256:7da8831f9a0fdb526621ba09a281fadc58ea12701bc709e7b8cbc362feabc295", size = 65016 },
{ url = "https://files.pythonhosted.org/packages/48/45/9d1780768d3b249accecc5a38c725eb1e203d44a191f7b7ff1941f7df60c/msgpack-1.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fd1b58e1431008a57247d6e7cc4faa41c3607e8e7d4aaf81f7c29ea013cb458", size = 72267 },
{ url = "https://files.pythonhosted.org/packages/e3/26/389b9c593eda2b8551b2e7126ad3a06af6f9b44274eb3a4f054d48ff7e47/msgpack-1.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ae497b11f4c21558d95de9f64fff7053544f4d1a17731c866143ed6bb4591238", size = 82359 },
{ url = "https://files.pythonhosted.org/packages/ab/65/7d1de38c8a22cf8b1551469159d4b6cf49be2126adc2482de50976084d78/msgpack-1.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:33be9ab121df9b6b461ff91baac6f2731f83d9b27ed948c5b9d1978ae28bf157", size = 79172 },
{ url = "https://files.pythonhosted.org/packages/0f/bd/cacf208b64d9577a62c74b677e1ada005caa9b69a05a599889d6fc2ab20a/msgpack-1.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f64ae8fe7ffba251fecb8408540c34ee9df1c26674c50c4544d72dbf792e5ce", size = 425013 },
{ url = "https://files.pythonhosted.org/packages/4d/ec/fd869e2567cc9c01278a736cfd1697941ba0d4b81a43e0aa2e8d71dab208/msgpack-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a494554874691720ba5891c9b0b39474ba43ffb1aaf32a5dac874effb1619e1a", size = 426905 },
{ url = "https://files.pythonhosted.org/packages/55/2a/35860f33229075bce803a5593d046d8b489d7ba2fc85701e714fc1aaf898/msgpack-1.1.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb643284ab0ed26f6957d969fe0dd8bb17beb567beb8998140b5e38a90974f6c", size = 407336 },
{ url = "https://files.pythonhosted.org/packages/8c/16/69ed8f3ada150bf92745fb4921bd621fd2cdf5a42e25eb50bcc57a5328f0/msgpack-1.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d275a9e3c81b1093c060c3837e580c37f47c51eca031f7b5fb76f7b8470f5f9b", size = 409485 },
{ url = "https://files.pythonhosted.org/packages/c6/b6/0c398039e4c6d0b2e37c61d7e0e9d13439f91f780686deb8ee64ecf1ae71/msgpack-1.1.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4fd6b577e4541676e0cc9ddc1709d25014d3ad9a66caa19962c4f5de30fc09ef", size = 412182 },
{ url = "https://files.pythonhosted.org/packages/b8/d0/0cf4a6ecb9bc960d624c93effaeaae75cbf00b3bc4a54f35c8507273cda1/msgpack-1.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bb29aaa613c0a1c40d1af111abf025f1732cab333f96f285d6a93b934738a68a", size = 419883 },
{ url = "https://files.pythonhosted.org/packages/62/83/9697c211720fa71a2dfb632cad6196a8af3abea56eece220fde4674dc44b/msgpack-1.1.1-cp312-cp312-win32.whl", hash = "sha256:870b9a626280c86cff9c576ec0d9cbcc54a1e5ebda9cd26dab12baf41fee218c", size = 65406 },
{ url = "https://files.pythonhosted.org/packages/c0/23/0abb886e80eab08f5e8c485d6f13924028602829f63b8f5fa25a06636628/msgpack-1.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:5692095123007180dca3e788bb4c399cc26626da51629a31d40207cb262e67f4", size = 72558 },
{ url = "https://files.pythonhosted.org/packages/a1/38/561f01cf3577430b59b340b51329803d3a5bf6a45864a55f4ef308ac11e3/msgpack-1.1.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3765afa6bd4832fc11c3749be4ba4b69a0e8d7b728f78e68120a157a4c5d41f0", size = 81677 },
{ url = "https://files.pythonhosted.org/packages/09/48/54a89579ea36b6ae0ee001cba8c61f776451fad3c9306cd80f5b5c55be87/msgpack-1.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8ddb2bcfd1a8b9e431c8d6f4f7db0773084e107730ecf3472f1dfe9ad583f3d9", size = 78603 },
{ url = "https://files.pythonhosted.org/packages/a0/60/daba2699b308e95ae792cdc2ef092a38eb5ee422f9d2fbd4101526d8a210/msgpack-1.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:196a736f0526a03653d829d7d4c5500a97eea3648aebfd4b6743875f28aa2af8", size = 420504 },
{ url = "https://files.pythonhosted.org/packages/20/22/2ebae7ae43cd8f2debc35c631172ddf14e2a87ffcc04cf43ff9df9fff0d3/msgpack-1.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d592d06e3cc2f537ceeeb23d38799c6ad83255289bb84c2e5792e5a8dea268a", size = 423749 },
{ url = "https://files.pythonhosted.org/packages/40/1b/54c08dd5452427e1179a40b4b607e37e2664bca1c790c60c442c8e972e47/msgpack-1.1.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4df2311b0ce24f06ba253fda361f938dfecd7b961576f9be3f3fbd60e87130ac", size = 404458 },
{ url = "https://files.pythonhosted.org/packages/2e/60/6bb17e9ffb080616a51f09928fdd5cac1353c9becc6c4a8abd4e57269a16/msgpack-1.1.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e4141c5a32b5e37905b5940aacbc59739f036930367d7acce7a64e4dec1f5e0b", size = 405976 },
{ url = "https://files.pythonhosted.org/packages/ee/97/88983e266572e8707c1f4b99c8fd04f9eb97b43f2db40e3172d87d8642db/msgpack-1.1.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b1ce7f41670c5a69e1389420436f41385b1aa2504c3b0c30620764b15dded2e7", size = 408607 },
{ url = "https://files.pythonhosted.org/packages/bc/66/36c78af2efaffcc15a5a61ae0df53a1d025f2680122e2a9eb8442fed3ae4/msgpack-1.1.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4147151acabb9caed4e474c3344181e91ff7a388b888f1e19ea04f7e73dc7ad5", size = 424172 },
{ url = "https://files.pythonhosted.org/packages/8c/87/a75eb622b555708fe0427fab96056d39d4c9892b0c784b3a721088c7ee37/msgpack-1.1.1-cp313-cp313-win32.whl", hash = "sha256:500e85823a27d6d9bba1d057c871b4210c1dd6fb01fbb764e37e4e8847376323", size = 65347 },
{ url = "https://files.pythonhosted.org/packages/ca/91/7dc28d5e2a11a5ad804cf2b7f7a5fcb1eb5a4966d66a5d2b41aee6376543/msgpack-1.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:6d489fba546295983abd142812bda76b57e33d0b9f5d5b71c09a583285506f69", size = 72341 },
]
[[package]]
name = "msgspec"
version = "0.19.0"
@@ -3220,7 +3270,7 @@ name = "nvidia-cudnn-cu12"
version = "9.5.1.17"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/78/4535c9c7f859a64781e43c969a3a7e84c54634e319a996d43ef32ce46f83/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2", size = 570988386 },
@@ -3231,7 +3281,7 @@ name = "nvidia-cufft-cu12"
version = "11.3.0.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/8f/16/73727675941ab8e6ffd86ca3a4b7b47065edcca7a997920b831f8147c99d/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5", size = 200221632 },
@@ -3260,9 +3310,9 @@ name = "nvidia-cusolver-cu12"
version = "11.7.1.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12" },
{ name = "nvidia-cusparse-cu12" },
{ name = "nvidia-nvjitlink-cu12" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/f0/6e/c2cf12c9ff8b872e92b4a5740701e51ff17689c4d726fca91875b07f655d/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c", size = 158229790 },
@@ -3274,7 +3324,7 @@ name = "nvidia-cusparse-cu12"
version = "12.5.4.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/06/1e/b8b7c2f4099a37b96af5c9bb158632ea9e5d9d27d7391d7eb8fc45236674/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73", size = 216561367 },
@@ -5785,8 +5835,8 @@ wheels = [
[package.optional-dependencies]
sentencepiece = [
{ name = "protobuf" },
{ name = "sentencepiece" },
{ name = "protobuf", marker = "sys_platform == 'darwin'" },
{ name = "sentencepiece", marker = "sys_platform == 'darwin'" },
]
[[package]]
@@ -5794,7 +5844,7 @@ name = "triton"
version = "3.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "setuptools" },
{ name = "setuptools", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/8d/a9/549e51e9b1b2c9b854fd761a1d23df0ba2fbc60bd0c13b489ffa518cfcb7/triton-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b74db445b1c562844d3cfad6e9679c72e93fdfb1a90a24052b03bb5c49d1242e", size = 155600257 },