Merge remote-tracking branch 'origin/main' into datastore-reproduce

Andy Lee
2025-07-12 05:42:16 +00:00
25 changed files with 2053 additions and 88 deletions

View File

@@ -1,4 +1,14 @@
# packages/leann-core/src/leann/__init__.py
import os
import platform
# Fix OpenMP threading issues on macOS ARM64
if platform.system() == "Darwin":
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["KMP_BLOCKTIME"] = "0"
from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import BACKEND_REGISTRY, autodiscover_backends
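
For context, a minimal standalone sketch of the same workaround: these variables only help if they are set before any OpenMP-backed library (torch, faiss, etc.) is imported, which is why the hunk above places them before the .api import. The extra platform.machine() check is an assumption added here for illustration and is not part of this commit.

# Illustrative sketch only, not part of the diff above.
import os
import platform

if platform.system() == "Darwin" and platform.machine() == "arm64":  # assumption: restrict to Apple Silicon
    os.environ.setdefault("OMP_NUM_THREADS", "1")
    os.environ.setdefault("MKL_NUM_THREADS", "1")
    os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
    os.environ.setdefault("KMP_BLOCKTIME", "0")

import torch  # imported only after the environment is configured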

View File

@@ -1,4 +1,3 @@
#!/usr/bin/env python3
"""
This file contains the core API for the LEANN project, now definitively updated
with the correct, original embedding logic from the user's reference code.
@@ -11,6 +10,7 @@ from pathlib import Path
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
import uuid
import torch
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
@@ -25,13 +25,22 @@ def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
raise RuntimeError(
f"sentence-transformers not available. Install with: pip install sentence-transformers"
) from e
# Load model using sentence-transformers
model = SentenceTransformer(model_name)
model = model.half()
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
# Use an accelerator if available: CUDA on NVIDIA GPUs, MPS on Apple Silicon
if torch.cuda.is_available():
model = model.to("cuda")
elif torch.backends.mps.is_available():
model = model.to("mps")
# Generate embeddings
embeddings = model.encode(chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=64)
return embeddings
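
A rough usage sketch for the updated helper; the chunk texts and model name below are made-up examples, and only the compute_embeddings signature comes from the hunk header above.

# Hypothetical call site; "sentence-transformers/all-MiniLM-L6-v2" is an example model name.
chunks = ["LEANN builds a compact vector index.", "LEANN answers questions over your documents."]
vectors = compute_embeddings(chunks, model_name="sentence-transformers/all-MiniLM-L6-v2")
print(vectors.shape)  # (2, embedding_dim) numpy array, ready to hand to a backend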
# --- Core API Classes (Restored and Unchanged) ---
@@ -181,5 +190,25 @@ class LeannChat:
def ask(self, question: str, top_k=5, **kwargs):
results = self.searcher.search(question, top_k=top_k, **kwargs)
context = "\n\n".join([r.text for r in results])
prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
f"{context}\n\n"
f"Question: {question}\n\n"
"Please provide the best answer you can based on this context and your knowledge."
)
return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
def start_interactive(self):
print("\nLeann Chat started (type 'quit' to exit)")
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() in ['quit', 'exit']:
break
if not user_input:
continue
response = self.ask(user_input)
print(f"Leann: {response}")
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break
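
A brief, hypothetical driver for the methods added above; it assumes a LeannChat instance named chat has already been constructed (its constructor is not shown in this hunk).

# Hypothetical usage; `chat` stands for an already-built LeannChat instance.
answer = chat.ask("What does LEANN keep in its index?", top_k=5)
print(answer)
chat.start_interactive()  # hands control to the quit/exit loop above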

View File

@@ -7,6 +7,7 @@ supporting different backends like Ollama, Hugging Face Transformers, and a simu
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import logging
import os
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -95,7 +96,57 @@ class HFChat(LLMInterface):
}
logger.info(f"Generating text with Hugging Face model with params: {params}")
results = self.pipeline(prompt, **params)
return results[0]['generated_text']
# Handle different response formats from transformers
if isinstance(results, list) and len(results) > 0:
generated_text = results[0].get('generated_text', '') if isinstance(results[0], dict) else str(results[0])
else:
generated_text = str(results)
# Extract only the newly generated portion by removing the original prompt
if isinstance(generated_text, str) and generated_text.startswith(prompt):
response = generated_text[len(prompt):].strip()
else:
# Fallback: return the full response if prompt removal fails
response = str(generated_text)
return response
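
The echo-stripping behaviour added above can be sanity-checked with plain strings, no model required; the strings below are invented for illustration.

# Minimal check of the prompt-removal logic using plain strings.
prompt = "Question: What is LEANN?\nAnswer:"
generated_text = prompt + " LEANN is a local RAG toolkit."
if generated_text.startswith(prompt):
    response = generated_text[len(prompt):].strip()
else:
    response = generated_text
print(response)  # -> "LEANN is a local RAG toolkit."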
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
if not self.api_key:
raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.")
logger.info(f"Initializing OpenAI Chat with model='{model}'")
try:
import openai
self.client = openai.OpenAI(api_key=self.api_key)
except ImportError:
raise ImportError("The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'.")
def ask(self, prompt: str, **kwargs) -> str:
# Default parameters for OpenAI
params = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.get("max_tokens", 1000),
"temperature": kwargs.get("temperature", 0.7),
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
}
logger.info(f"Sending request to OpenAI with model {self.model}")
try:
response = self.client.chat.completions.create(**params)
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"Error communicating with OpenAI: {e}")
return f"Error: Could not get a response from OpenAI. Details: {e}"
class SimulatedChat(LLMInterface):
"""A simple simulated chat for testing and development."""
@@ -127,9 +178,11 @@ def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")
if llm_type == "ollama":
return OllamaChat(model=model, host=llm_config.get("host"))
return OllamaChat(model=model or "llama3:8b", host=llm_config.get("host", "http://localhost:11434"))
elif llm_type == "hf":
return HFChat(model_name=model)
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
elif llm_type == "simulated":
return SimulatedChat()
else: