Merge remote-tracking branch 'origin/main' into datastore-reproduce

This commit is contained in:
Andy Lee
2025-07-12 05:42:16 +00:00
25 changed files with 2053 additions and 88 deletions

1
packages/__init__.py Normal file
View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@
# This file makes the directory a Python package

View File

@@ -15,6 +15,8 @@ import os
from contextlib import contextmanager
import zmq
import numpy as np
from pathlib import Path
import pickle
RED = "\033[91m"
RESET = "\033[0m"
@@ -109,8 +111,6 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
Load passages from a JSONL file with label map support
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
"""
from pathlib import Path
import pickle
if not os.path.exists(passages_file):
raise FileNotFoundError(f"Passages file {passages_file} not found.")
@@ -210,7 +210,6 @@ def create_embedding_server_thread(
passages = load_passages_from_metadata(passages_file)
else:
# Try to find metadata file in same directory
from pathlib import Path
passages_dir = Path(passages_file).parent
meta_files = list(passages_dir.glob("*.meta.json"))
if meta_files:

View File

@@ -2,6 +2,33 @@
cmake_minimum_required(VERSION 3.24)
project(leann_backend_hnsw_wrapper)
# Set OpenMP path for macOS
if(APPLE)
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
set(OpenMP_C_LIB_NAMES "omp")
set(OpenMP_CXX_LIB_NAMES "omp")
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
endif()
# Build ZeroMQ from source
set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
set(WITH_DOCS OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
set(BUILD_STATIC ON CACHE BOOL "" FORCE)
add_subdirectory(third_party/libzmq)
# Add cppzmq headers
include_directories(third_party/cppzmq)
# Configure msgpack-c - disable boost dependency
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
add_compile_definitions(MSGPACK_NO_BOOST)
include_directories(third_party/msgpack-c/include)
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)

View File

@@ -4,6 +4,7 @@ import json
from pathlib import Path
from typing import Dict, Any, List
import pickle
import shutil
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -77,17 +78,29 @@ class HNSWBuilder(LeannBackendBuilderInterface):
self._convert_to_csr(index_file)
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
print(f"INFO: Converting HNSW index to {mode_str} format...")
csr_temp_file = index_file.with_suffix(".csr.tmp")
success = convert_hnsw_graph_to_csr(
str(index_file), str(csr_temp_file), prune_embeddings=self.is_recompute
str(index_file),
str(csr_temp_file),
prune_embeddings=self.is_recompute
)
if success:
import shutil
print("✅ CSR conversion successful.")
index_file_old = index_file.with_suffix(".old")
shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file))
print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
else:
# Clean up and fail fast
if csr_temp_file.exists():
os.remove(csr_temp_file)
raise RuntimeError("CSR conversion failed")
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
class HNSWSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs):
@@ -99,7 +112,10 @@ class HNSWSearcher(BaseSearcher):
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
self.is_compact, self.is_pruned = self._get_index_storage_status_from_meta()
self.is_compact, self.is_pruned = (
self.meta.get('is_compact', True),
self.meta.get('is_pruned', True)
)
index_file = self.index_dir / f"{self.index_path.stem}.index"
if not index_file.exists():
@@ -114,11 +130,6 @@ class HNSWSearcher(BaseSearcher):
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
def _get_index_storage_status_from_meta(self) -> tuple[bool, bool]:
is_compact = self.meta.get('is_compact', True)
is_pruned = self.meta.get('is_pruned', True)
return is_compact, is_pruned
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
from . import faiss

View File

@@ -1,4 +1,14 @@
# packages/leann-core/src/leann/__init__.py
import os
import platform
# Fix OpenMP threading issues on macOS ARM64
if platform.system() == "Darwin":
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["KMP_BLOCKTIME"] = "0"
from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import BACKEND_REGISTRY, autodiscover_backends

View File

@@ -1,4 +1,3 @@
#!/usr/bin/env python3
"""
This file contains the core API for the LEANN project, now definitively updated
with the correct, original embedding logic from the user's reference code.
@@ -11,6 +10,7 @@ from pathlib import Path
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
import uuid
import torch
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
@@ -25,13 +25,22 @@ def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
raise RuntimeError(
f"sentence-transformers not available. Install with: pip install sentence-transformers"
) from e
# Load model using sentence-transformers
model = SentenceTransformer(model_name)
model = model.half()
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
# use acclerater GPU or MAC GPU
if torch.cuda.is_available():
model = model.to("cuda")
elif torch.backends.mps.is_available():
model = model.to("mps")
# Generate embeddings
embeddings = model.encode(chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=64)
return embeddings
# --- Core API Classes (Restored and Unchanged) ---
@@ -181,5 +190,25 @@ class LeannChat:
def ask(self, question: str, top_k=5, **kwargs):
results = self.searcher.search(question, top_k=top_k, **kwargs)
context = "\n\n".join([r.text for r in results])
prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
f"{context}\n\n"
f"Question: {question}\n\n"
"Please provide the best answer you can based on this context and your knowledge."
)
return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
def start_interactive(self):
print("\nLeann Chat started (type 'quit' to exit)")
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() in ['quit', 'exit']:
break
if not user_input:
continue
response = self.ask(user_input)
print(f"Leann: {response}")
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break

View File

@@ -7,6 +7,7 @@ supporting different backends like Ollama, Hugging Face Transformers, and a simu
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import logging
import os
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -95,7 +96,57 @@ class HFChat(LLMInterface):
}
logger.info(f"Generating text with Hugging Face model with params: {params}")
results = self.pipeline(prompt, **params)
return results[0]['generated_text']
# Handle different response formats from transformers
if isinstance(results, list) and len(results) > 0:
generated_text = results[0].get('generated_text', '') if isinstance(results[0], dict) else str(results[0])
else:
generated_text = str(results)
# Extract only the newly generated portion by removing the original prompt
if isinstance(generated_text, str) and generated_text.startswith(prompt):
response = generated_text[len(prompt):].strip()
else:
# Fallback: return the full response if prompt removal fails
response = str(generated_text)
return response
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
if not self.api_key:
raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.")
logger.info(f"Initializing OpenAI Chat with model='{model}'")
try:
import openai
self.client = openai.OpenAI(api_key=self.api_key)
except ImportError:
raise ImportError("The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'.")
def ask(self, prompt: str, **kwargs) -> str:
# Default parameters for OpenAI
params = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.get("max_tokens", 1000),
"temperature": kwargs.get("temperature", 0.7),
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
}
logger.info(f"Sending request to OpenAI with model {self.model}")
try:
response = self.client.chat.completions.create(**params)
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"Error communicating with OpenAI: {e}")
return f"Error: Could not get a response from OpenAI. Details: {e}"
class SimulatedChat(LLMInterface):
"""A simple simulated chat for testing and development."""
@@ -127,9 +178,11 @@ def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")
if llm_type == "ollama":
return OllamaChat(model=model, host=llm_config.get("host"))
return OllamaChat(model=model or "llama3:8b", host=llm_config.get("host", "http://localhost:11434"))
elif llm_type == "hf":
return HFChat(model_name=model)
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
elif llm_type == "simulated":
return SimulatedChat()
else: