Merge remote-tracking branch 'origin/main' into datastore-reproduce
packages/__init__.py (new file, 1 line added)
@@ -0,0 +1 @@

packages/leann-backend-diskann/__init__.py (new file, 1 line added)
@@ -0,0 +1 @@
# This file makes the directory a Python package
@@ -15,6 +15,8 @@ import os
from contextlib import contextmanager
import zmq
import numpy as np
from pathlib import Path
import pickle

RED = "\033[91m"
RESET = "\033[0m"
@@ -109,8 +111,6 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
    Load passages from a JSONL file with label map support
    Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
    """
    from pathlib import Path
    import pickle

    if not os.path.exists(passages_file):
        raise FileNotFoundError(f"Passages file {passages_file} not found.")
@@ -210,7 +210,6 @@ def create_embedding_server_thread(
        passages = load_passages_from_metadata(passages_file)
    else:
        # Try to find metadata file in same directory
        from pathlib import Path
        passages_dir = Path(passages_file).parent
        meta_files = list(passages_dir.glob("*.meta.json"))
        if meta_files:
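For reference, a record in the JSONL passages file documented above could be produced like this (the id, text, metadata, and file name are made-up examples, not values from this commit):

import json

record = {
    "id": "passage-0001",
    "text": "LEANN loads passages from a JSONL file.",
    "metadata": {"source": "example.md"},
}
with open("passages.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record) + "\n")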
@@ -2,6 +2,33 @@
cmake_minimum_required(VERSION 3.24)
project(leann_backend_hnsw_wrapper)

# Set OpenMP path for macOS
if(APPLE)
    set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
    set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
    set(OpenMP_C_LIB_NAMES "omp")
    set(OpenMP_CXX_LIB_NAMES "omp")
    set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
endif()

# Build ZeroMQ from source
set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
set(WITH_DOCS OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
set(BUILD_STATIC ON CACHE BOOL "" FORCE)
add_subdirectory(third_party/libzmq)

# Add cppzmq headers
include_directories(third_party/cppzmq)

# Configure msgpack-c - disable boost dependency
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
add_compile_definitions(MSGPACK_NO_BOOST)
include_directories(third_party/msgpack-c/include)

set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
@@ -4,6 +4,7 @@ import json
from pathlib import Path
from typing import Dict, Any, List
import pickle
import shutil

from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -77,17 +78,29 @@ class HNSWBuilder(LeannBackendBuilderInterface):
        self._convert_to_csr(index_file)

    def _convert_to_csr(self, index_file: Path):
        """Convert built index to CSR format"""
        mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
        print(f"INFO: Converting HNSW index to {mode_str} format...")

        csr_temp_file = index_file.with_suffix(".csr.tmp")

        success = convert_hnsw_graph_to_csr(
            str(index_file), str(csr_temp_file), prune_embeddings=self.is_recompute
            str(index_file),
            str(csr_temp_file),
            prune_embeddings=self.is_recompute
        )

        if success:
            import shutil
            print("✅ CSR conversion successful.")
            index_file_old = index_file.with_suffix(".old")
            shutil.move(str(index_file), str(index_file_old))
            shutil.move(str(csr_temp_file), str(index_file))
            print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
        else:
            # Clean up and fail fast
            if csr_temp_file.exists():
                os.remove(csr_temp_file)
            raise RuntimeError("CSR conversion failed")
            raise RuntimeError("CSR conversion failed - cannot proceed with compact format")

class HNSWSearcher(BaseSearcher):
    def __init__(self, index_path: str, **kwargs):
@@ -99,7 +112,10 @@ class HNSWSearcher(BaseSearcher):
        if metric_enum is None:
            raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")

        self.is_compact, self.is_pruned = self._get_index_storage_status_from_meta()
        self.is_compact, self.is_pruned = (
            self.meta.get('is_compact', True),
            self.meta.get('is_pruned', True)
        )

        index_file = self.index_dir / f"{self.index_path.stem}.index"
        if not index_file.exists():
@@ -114,11 +130,6 @@ class HNSWSearcher(BaseSearcher):
        self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)

    def _get_index_storage_status_from_meta(self) -> tuple[bool, bool]:
        is_compact = self.meta.get('is_compact', True)
        is_pruned = self.meta.get('is_pruned', True)
        return is_compact, is_pruned

    def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
        from . import faiss
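As a rough sketch of the converter used by _convert_to_csr above: the function takes the source index path, a temporary output path, and a prune flag, and returns a success boolean. The absolute import path and file names below are assumptions; the diff only shows the relative import from .convert_to_csr.

from leann_backend_hnsw.convert_to_csr import convert_hnsw_graph_to_csr  # package name assumed

ok = convert_hnsw_graph_to_csr("demo.index", "demo.csr.tmp", prune_embeddings=True)
if not ok:
    raise RuntimeError("CSR conversion failed")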
@@ -1,4 +1,14 @@
# packages/leann-core/src/leann/__init__.py
import os
import platform

# Fix OpenMP threading issues on macOS ARM64
if platform.system() == "Darwin":
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
    os.environ["KMP_BLOCKTIME"] = "0"

from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import BACKEND_REGISTRY, autodiscover_backends
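A minimal sketch of what this __init__.py exposes once the merge lands; whether autodiscover_backends takes arguments and how BACKEND_REGISTRY is iterated are assumptions based only on the names imported above.

# Importing the package applies the macOS OpenMP workaround above before anything else runs.
from leann import LeannBuilder, LeannChat, LeannSearcher
from leann import BACKEND_REGISTRY, autodiscover_backends

autodiscover_backends()        # assumed to register available backends (e.g. HNSW, DiskANN)
print(list(BACKEND_REGISTRY))  # assumed to be iterable over registered backend names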
@@ -1,4 +1,3 @@
#!/usr/bin/env python3
"""
This file contains the core API for the LEANN project, now definitively updated
with the correct, original embedding logic from the user's reference code.
@@ -11,6 +10,7 @@ from pathlib import Path
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
import uuid
import torch

from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
@@ -25,13 +25,22 @@ def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
        raise RuntimeError(
            f"sentence-transformers not available. Install with: pip install sentence-transformers"
        ) from e

    # Load model using sentence-transformers
    model = SentenceTransformer(model_name)

    model = model.half()
    print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
    # use accelerator GPU (CUDA) or Mac GPU (MPS)
    if torch.cuda.is_available():
        model = model.to("cuda")
    elif torch.backends.mps.is_available():
        model = model.to("mps")

    # Generate embeddings
    embeddings = model.encode(chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=64)

    return embeddings

# --- Core API Classes (Restored and Unchanged) ---
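A quick usage sketch for compute_embeddings above; the leann.api import path and the model name are illustrative assumptions.

import numpy as np
from leann.api import compute_embeddings  # module path assumed from the file shown in this diff

chunks = ["LEANN prunes HNSW graphs into a CSR layout.", "Embeddings can be recomputed at query time."]
embeddings = compute_embeddings(chunks, model_name="sentence-transformers/all-MiniLM-L6-v2")  # example model
assert isinstance(embeddings, np.ndarray)
print(embeddings.shape)  # (len(chunks), embedding_dim)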
@@ -181,5 +190,25 @@ class LeannChat:
    def ask(self, question: str, top_k=5, **kwargs):
        results = self.searcher.search(question, top_k=top_k, **kwargs)
        context = "\n\n".join([r.text for r in results])
        prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
        return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
        prompt = (
            "Here is some retrieved context that might help answer your question:\n\n"
            f"{context}\n\n"
            f"Question: {question}\n\n"
            "Please provide the best answer you can based on this context and your knowledge."
        )
        return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))

    def start_interactive(self):
        print("\nLeann Chat started (type 'quit' to exit)")
        while True:
            try:
                user_input = input("You: ").strip()
                if user_input.lower() in ['quit', 'exit']:
                    break
                if not user_input:
                    continue
                response = self.ask(user_input)
                print(f"Leann: {response}")
            except (KeyboardInterrupt, EOFError):
                print("\nGoodbye!")
                break
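A hypothetical interaction with LeannChat: only ask() and start_interactive() appear in this hunk, so the constructor arguments below are invented for illustration.

chat = LeannChat(index_path="indexes/demo", llm_config={"type": "openai", "model": "gpt-4o"})  # hypothetical signature

answer = chat.ask("What does the CSR conversion step do?", top_k=5)
print(answer)

chat.start_interactive()  # REPL loop shown above; exits on 'quit', 'exit', or Ctrl-C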
@@ -7,6 +7,7 @@ supporting different backends like Ollama, Hugging Face Transformers, and a simu
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import logging
import os

# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -95,7 +96,57 @@ class HFChat(LLMInterface):
        }
        logger.info(f"Generating text with Hugging Face model with params: {params}")
        results = self.pipeline(prompt, **params)
        return results[0]['generated_text']

        # Handle different response formats from transformers
        if isinstance(results, list) and len(results) > 0:
            generated_text = results[0].get('generated_text', '') if isinstance(results[0], dict) else str(results[0])
        else:
            generated_text = str(results)

        # Extract only the newly generated portion by removing the original prompt
        if isinstance(generated_text, str) and generated_text.startswith(prompt):
            response = generated_text[len(prompt):].strip()
        else:
            # Fallback: return the full response if prompt removal fails
            response = str(generated_text)

        return response

class OpenAIChat(LLMInterface):
    """LLM interface for OpenAI models."""
    def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
        self.model = model
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")

        if not self.api_key:
            raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.")

        logger.info(f"Initializing OpenAI Chat with model='{model}'")

        try:
            import openai
            self.client = openai.OpenAI(api_key=self.api_key)
        except ImportError:
            raise ImportError("The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'.")

    def ask(self, prompt: str, **kwargs) -> str:
        # Default parameters for OpenAI
        params = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": kwargs.get("max_tokens", 1000),
            "temperature": kwargs.get("temperature", 0.7),
            **{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
        }

        logger.info(f"Sending request to OpenAI with model {self.model}")

        try:
            response = self.client.chat.completions.create(**params)
            return response.choices[0].message.content.strip()
        except Exception as e:
            logger.error(f"Error communicating with OpenAI: {e}")
            return f"Error: Could not get a response from OpenAI. Details: {e}"

class SimulatedChat(LLMInterface):
    """A simple simulated chat for testing and development."""
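A small usage sketch for the OpenAIChat class added above, assuming OPENAI_API_KEY is set in the environment:

llm = OpenAIChat(model="gpt-4o")  # api_key falls back to the OPENAI_API_KEY environment variable
answer = llm.ask("Summarize the CSR pruning step in one sentence.", max_tokens=200, temperature=0.2)
print(answer)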
@@ -127,9 +178,11 @@ def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
    logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")

    if llm_type == "ollama":
        return OllamaChat(model=model, host=llm_config.get("host"))
        return OllamaChat(model=model or "llama3:8b", host=llm_config.get("host", "http://localhost:11434"))
    elif llm_type == "hf":
        return HFChat(model_name=model)
        return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
    elif llm_type == "openai":
        return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
    elif llm_type == "simulated":
        return SimulatedChat()
    else:
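A sketch of calling the dispatcher above; the llm_config keys "type" and "model" are assumptions, since this hunk only shows the "host" and "api_key" lookups.

llm = get_llm({"type": "ollama", "model": "llama3:8b"})  # key names assumed
print(llm.ask("Hello from LEANN!"))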