fix: resolve all ruff linting errors and add lint CI check

- Fix ambiguous fullwidth characters (commas, parentheses) in strings and comments
- Replace Chinese comments with English equivalents
- Fix unused imports with proper noqa annotations for intentional imports
- Fix bare except clauses with specific exception types
- Fix redefined variables and undefined names
- Add ruff noqa annotations for generated protobuf files
- Add lint and format check to GitHub Actions CI pipeline
This commit is contained in:
Andy Lee
2025-07-26 22:35:12 -07:00
parent 8537a6b17e
commit b3e9ee96fa
53 changed files with 5655 additions and 5220 deletions

View File

@@ -14,4 +14,4 @@ from .registry import BACKEND_REGISTRY, autodiscover_backends
autodiscover_backends()
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"]
__all__ = ["BACKEND_REGISTRY", "LeannBuilder", "LeannChat", "LeannSearcher"]

View File

@@ -4,27 +4,30 @@ with the correct, original embedding logic from the user's reference code.
"""
import json
import pickle
from leann.interface import LeannBackendSearcherInterface
import numpy as np
import time
from pathlib import Path
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
from .chat import get_llm
import logging
import pickle
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
import numpy as np
from leann.interface import LeannBackendSearcherInterface
from .chat import get_llm
from .interface import LeannBackendFactoryInterface
from .registry import BACKEND_REGISTRY
logger = logging.getLogger(__name__)
def compute_embeddings(
chunks: List[str],
chunks: list[str],
model_name: str,
mode: str = "sentence-transformers",
use_server: bool = True,
port: Optional[int] = None,
port: int | None = None,
is_build=False,
) -> np.ndarray:
"""
@@ -61,9 +64,7 @@ def compute_embeddings(
)
def compute_embeddings_via_server(
chunks: List[str], model_name: str, port: int
) -> np.ndarray:
def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
"""Computes embeddings using sentence-transformers.
Args:
@@ -73,9 +74,9 @@ def compute_embeddings_via_server(
logger.info(
f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
)
import zmq
import msgpack
import numpy as np
import zmq
# Connect to embedding server
context = zmq.Context()
@@ -104,11 +105,11 @@ class SearchResult:
id: str
score: float
text: str
metadata: Dict[str, Any] = field(default_factory=dict)
metadata: dict[str, Any] = field(default_factory=dict)
class PassageManager:
def __init__(self, passage_sources: List[Dict[str, Any]]):
def __init__(self, passage_sources: list[dict[str, Any]]):
self.offset_maps = {}
self.passage_files = {}
self.global_offset_map = {} # Combined map for fast lookup
@@ -117,15 +118,15 @@ class PassageManager:
assert source["type"] == "jsonl", "only jsonl is supported"
passage_file = source["path"]
index_file = source["index_path"] # .idx file
# Fix path resolution for Colab and other environments
if not Path(index_file).is_absolute():
# If relative path, try to resolve it properly
index_file = str(Path(index_file).resolve())
if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, "rb") as f:
offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map
@@ -135,11 +136,11 @@ class PassageManager:
for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset)
def get_passage(self, passage_id: str) -> Dict[str, Any]:
def get_passage(self, passage_id: str) -> dict[str, Any]:
if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id]
# Lazy file opening - only open when needed
with open(passage_file, "r", encoding="utf-8") as f:
with open(passage_file, encoding="utf-8") as f:
f.seek(offset)
return json.loads(f.readline())
raise KeyError(f"Passage ID not found: {passage_id}")
@@ -150,14 +151,12 @@ class LeannBuilder:
self,
backend_name: str,
embedding_model: str = "facebook/contriever",
dimensions: Optional[int] = None,
dimensions: int | None = None,
embedding_mode: str = "sentence-transformers",
**backend_kwargs,
):
self.backend_name = backend_name
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(
backend_name
)
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
self.backend_factory = backend_factory
@@ -165,9 +164,9 @@ class LeannBuilder:
self.dimensions = dimensions
self.embedding_mode = embedding_mode
self.backend_kwargs = backend_kwargs
self.chunks: List[Dict[str, Any]] = []
self.chunks: list[dict[str, Any]] = []
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
def add_text(self, text: str, metadata: dict[str, Any] | None = None):
if metadata is None:
metadata = {}
passage_id = metadata.get("id", str(len(self.chunks)))
@@ -197,9 +196,7 @@ class LeannBuilder:
try:
from tqdm import tqdm
chunk_iterator = tqdm(
self.chunks, desc="Writing passages", unit="chunk"
)
chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
except ImportError:
chunk_iterator = self.chunks
@@ -229,9 +226,7 @@ class LeannBuilder:
string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(
embeddings, string_ids, index_path, **current_backend_kwargs
)
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
leann_meta_path = index_dir / f"{index_name}.meta.json"
meta_data = {
"version": "1.0",
@@ -280,9 +275,7 @@ class LeannBuilder:
ids, embeddings = data
if not isinstance(embeddings, np.ndarray):
raise ValueError(
f"Expected embeddings to be numpy array, got {type(embeddings)}"
)
raise ValueError(f"Expected embeddings to be numpy array, got {type(embeddings)}")
if len(ids) != embeddings.shape[0]:
raise ValueError(
@@ -294,9 +287,7 @@ class LeannBuilder:
if self.dimensions is None:
self.dimensions = embedding_dim
elif self.dimensions != embedding_dim:
raise ValueError(
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
)
raise ValueError(f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}")
logger.info(
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
@@ -381,9 +372,7 @@ class LeannBuilder:
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
logger.info(
f"Index built successfully from precomputed embeddings: {index_path}"
)
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
class LeannSearcher:
@@ -391,20 +380,16 @@ class LeannSearcher:
# Fix path resolution for Colab and other environments
if not Path(index_path).is_absolute():
index_path = str(Path(index_path).resolve())
self.meta_path_str = f"{index_path}.meta.json"
if not Path(self.meta_path_str).exists():
raise FileNotFoundError(
f"Leann metadata file not found at {self.meta_path_str}"
)
with open(self.meta_path_str, "r", encoding="utf-8") as f:
raise FileNotFoundError(f"Leann metadata file not found at {self.meta_path_str}")
with open(self.meta_path_str, encoding="utf-8") as f:
self.meta_data = json.load(f)
backend_name = self.meta_data["backend_name"]
self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format
self.embedding_mode = self.meta_data.get(
"embedding_mode", "sentence-transformers"
)
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
@@ -426,7 +411,7 @@ class LeannSearcher:
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: int = 5557,
**kwargs,
) -> List[SearchResult]:
) -> list[SearchResult]:
logger.info("🔍 LeannSearcher.search() called:")
logger.info(f" Query: '{query}'")
logger.info(f" Top_k: {top_k}")
@@ -453,7 +438,7 @@ class LeannSearcher:
zmq_port=zmq_port,
)
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
time.time() - start_time
# logger.info(f" Embedding time: {embedding_time} seconds")
start_time = time.time()
@@ -468,17 +453,15 @@ class LeannSearcher:
zmq_port=zmq_port,
**kwargs,
)
search_time = time.time() - start_time
time.time() - start_time
# logger.info(f" Search time: {search_time} seconds")
logger.info(
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
)
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
enriched_results = []
if "labels" in results and "distances" in results:
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
for i, (string_id, dist) in enumerate(
zip(results["labels"][0], results["distances"][0])
zip(results["labels"][0], results["distances"][0], strict=False)
):
try:
passage_data = self.passage_manager.get_passage(string_id)
@@ -490,15 +473,15 @@ class LeannSearcher:
metadata=passage_data.get("metadata", {}),
)
)
# Color codes for better logging
GREEN = "\033[92m"
BLUE = "\033[94m"
YELLOW = "\033[93m"
RESET = "\033[0m"
# Truncate text for display (first 100 chars)
display_text = passage_data['text']
display_text = passage_data["text"]
logger.info(
f" {GREEN}{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
)
@@ -516,7 +499,7 @@ class LeannChat:
def __init__(
self,
index_path: str,
llm_config: Optional[Dict[str, Any]] = None,
llm_config: dict[str, Any] | None = None,
enable_warmup: bool = False,
**kwargs,
):
@@ -532,7 +515,7 @@ class LeannChat:
prune_ratio: float = 0.0,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
llm_kwargs: Optional[Dict[str, Any]] = None,
llm_kwargs: dict[str, Any] | None = None,
expected_zmq_port: int = 5557,
**search_kwargs,
):

View File

@@ -4,11 +4,12 @@ This file contains the chat generation logic for the LEANN project,
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
import difflib
import logging
import os
import difflib
from abc import ABC, abstractmethod
from typing import Any
import torch
# Configure logging
@@ -16,10 +17,11 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def check_ollama_models() -> List[str]:
def check_ollama_models() -> list[str]:
"""Check available Ollama models and return a list"""
try:
import requests
response = requests.get("http://localhost:11434/api/tags", timeout=5)
if response.status_code == 200:
data = response.json()
@@ -31,51 +33,52 @@ def check_ollama_models() -> List[str]:
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
"""Check if a model exists in Ollama's remote library and return available tags
Returns:
(model_exists, available_tags): bool and list of matching tags
"""
try:
import requests
import re
import requests
# Split model name and tag
if ':' in model_name:
base_model, requested_tag = model_name.split(':', 1)
if ":" in model_name:
base_model, requested_tag = model_name.split(":", 1)
else:
base_model, requested_tag = model_name, None
# First check if base model exists in library
library_response = requests.get("https://ollama.com/library", timeout=8)
if library_response.status_code != 200:
return True, [] # Assume exists if can't check
# Extract model names from library page
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
if base_model not in models_in_library:
return False, [] # Base model doesn't exist
# If base model exists, get available tags
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
if tags_response.status_code != 200:
return True, [] # Base model exists but can't get tags
# Extract tags for this model - be more specific to avoid HTML artifacts
tag_pattern = rf'{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+'
tag_pattern = rf"{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+"
raw_tags = re.findall(tag_pattern, tags_response.text)
# Clean up tags - remove HTML artifacts and duplicates
available_tags = []
seen = set()
for tag in raw_tags:
# Skip if it looks like HTML (contains < or >)
if '<' in tag or '>' in tag:
if "<" in tag or ">" in tag:
continue
if tag not in seen:
seen.add(tag)
available_tags.append(tag)
# Check if exact model exists
if requested_tag is None:
# User just requested base model, suggest tags
@@ -83,76 +86,80 @@ def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]
else:
exact_match = model_name in available_tags
return exact_match, available_tags[:10]
except Exception:
pass
# If scraping fails, assume model might exist (don't block user)
return True, []
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
"""Use intelligent fuzzy search for Ollama models"""
if not available_models:
return []
query_lower = query.lower()
suggestions = []
# 1. Exact matches first
exact_matches = [m for m in available_models if query_lower == m.lower()]
suggestions.extend(exact_matches)
# 2. Starts with query
starts_with = [m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions]
starts_with = [
m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions
]
suggestions.extend(starts_with)
# 3. Contains query
contains = [m for m in available_models if query_lower in m.lower() and m not in suggestions]
suggestions.extend(contains)
# 4. Base model name matching (remove version numbers)
def get_base_name(model_name: str) -> str:
"""Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
return model_name.split(':')[0].split('-')[0]
return model_name.split(":")[0].split("-")[0]
query_base = get_base_name(query_lower)
base_matches = [
m for m in available_models
m
for m in available_models
if get_base_name(m.lower()) == query_base and m not in suggestions
]
suggestions.extend(base_matches)
# 5. Family/variant matching
model_families = {
'llama': ['llama2', 'llama3', 'alpaca', 'vicuna', 'codellama'],
'qwen': ['qwen', 'qwen2', 'qwen3'],
'gemma': ['gemma', 'gemma2'],
'phi': ['phi', 'phi2', 'phi3'],
'mistral': ['mistral', 'mixtral', 'openhermes'],
'dolphin': ['dolphin', 'openchat'],
'deepseek': ['deepseek', 'deepseek-coder']
"llama": ["llama2", "llama3", "alpaca", "vicuna", "codellama"],
"qwen": ["qwen", "qwen2", "qwen3"],
"gemma": ["gemma", "gemma2"],
"phi": ["phi", "phi2", "phi3"],
"mistral": ["mistral", "mixtral", "openhermes"],
"dolphin": ["dolphin", "openchat"],
"deepseek": ["deepseek", "deepseek-coder"],
}
query_family = None
for family, variants in model_families.items():
if any(variant in query_lower for variant in variants):
query_family = family
break
if query_family:
family_variants = model_families[query_family]
family_matches = [
m for m in available_models
m
for m in available_models
if any(variant in m.lower() for variant in family_variants) and m not in suggestions
]
suggestions.extend(family_matches)
# 6. Use difflib for remaining fuzzy matches
remaining_models = [m for m in available_models if m not in suggestions]
difflib_matches = difflib.get_close_matches(query_lower, remaining_models, n=3, cutoff=0.4)
suggestions.extend(difflib_matches)
return suggestions[:8] # Return top 8 suggestions
@@ -162,15 +169,13 @@ def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[
# Remove this too - no need for fallback
def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
"""Use difflib to find similar model names"""
if not available_models:
return []
# Get close matches using fuzzy matching
suggestions = difflib.get_close_matches(
invalid_model, available_models, n=3, cutoff=0.3
)
suggestions = difflib.get_close_matches(invalid_model, available_models, n=3, cutoff=0.3)
return suggestions
@@ -178,49 +183,50 @@ def check_hf_model_exists(model_name: str) -> bool:
"""Quick check if HuggingFace model exists without downloading"""
try:
from huggingface_hub import model_info
model_info(model_name)
return True
except Exception:
return False
def get_popular_hf_models() -> List[str]:
def get_popular_hf_models() -> list[str]:
"""Return a list of popular HuggingFace models for suggestions"""
try:
from huggingface_hub import list_models
# Get popular text-generation models, sorted by downloads
models = list_models(
filter="text-generation",
sort="downloads",
direction=-1,
limit=20 # Get top 20 most downloaded
limit=20, # Get top 20 most downloaded
)
# Extract model names and filter for chat/conversation models
model_names = []
chat_keywords = ['chat', 'instruct', 'dialog', 'conversation', 'assistant']
chat_keywords = ["chat", "instruct", "dialog", "conversation", "assistant"]
for model in models:
model_name = model.id if hasattr(model, 'id') else str(model)
model_name = model.id if hasattr(model, "id") else str(model)
# Prioritize models with chat-related keywords
if any(keyword in model_name.lower() for keyword in chat_keywords):
model_names.append(model_name)
elif len(model_names) < 10: # Fill up with other popular models
model_names.append(model_name)
return model_names[:10] if model_names else _get_fallback_hf_models()
except Exception:
# Fallback to static list if API call fails
return _get_fallback_hf_models()
def _get_fallback_hf_models() -> List[str]:
def _get_fallback_hf_models() -> list[str]:
"""Fallback list of popular HuggingFace models"""
return [
"microsoft/DialoGPT-medium",
"microsoft/DialoGPT-large",
"microsoft/DialoGPT-large",
"facebook/blenderbot-400M-distill",
"microsoft/phi-2",
"deepseek-ai/deepseek-llm-7b-chat",
@@ -228,44 +234,40 @@ def _get_fallback_hf_models() -> List[str]:
"facebook/blenderbot_small-90M",
"microsoft/phi-1_5",
"facebook/opt-350m",
"EleutherAI/gpt-neo-1.3B"
"EleutherAI/gpt-neo-1.3B",
]
def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
"""Use HuggingFace Hub's native fuzzy search for model suggestions"""
try:
from huggingface_hub import list_models
# HF Hub's search is already fuzzy! It handles typos and partial matches
models = list_models(
search=query,
filter="text-generation",
sort="downloads",
direction=-1,
limit=limit
search=query, filter="text-generation", sort="downloads", direction=-1, limit=limit
)
model_names = [model.id if hasattr(model, 'id') else str(model) for model in models]
model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
# If direct search doesn't return enough results, try some variations
if len(model_names) < 3:
# Try searching for partial matches or common variations
variations = []
# Extract base name (e.g., "gpt3" from "gpt-3.5")
base_query = query.lower().replace('-', '').replace('.', '').replace('_', '')
base_query = query.lower().replace("-", "").replace(".", "").replace("_", "")
if base_query != query.lower():
variations.append(base_query)
# Try common model name patterns
if 'gpt' in query.lower():
variations.extend(['gpt2', 'gpt-neo', 'gpt-j', 'dialoGPT'])
elif 'llama' in query.lower():
variations.extend(['llama2', 'alpaca', 'vicuna'])
elif 'bert' in query.lower():
variations.extend(['roberta', 'distilbert', 'albert'])
if "gpt" in query.lower():
variations.extend(["gpt2", "gpt-neo", "gpt-j", "dialoGPT"])
elif "llama" in query.lower():
variations.extend(["llama2", "alpaca", "vicuna"])
elif "bert" in query.lower():
variations.extend(["roberta", "distilbert", "albert"])
# Search with variations
for var in variations[:2]: # Limit to 2 variations to avoid too many API calls
try:
@@ -274,13 +276,15 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
filter="text-generation",
sort="downloads",
direction=-1,
limit=3
limit=3,
)
var_names = [model.id if hasattr(model, 'id') else str(model) for model in var_models]
var_names = [
model.id if hasattr(model, "id") else str(model) for model in var_models
]
model_names.extend(var_names)
except:
except Exception:
continue
# Remove duplicates while preserving order
seen = set()
unique_models = []
@@ -288,65 +292,67 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
if model not in seen:
seen.add(model)
unique_models.append(model)
return unique_models[:limit]
except Exception:
# If search fails, return empty list
return []
def search_hf_models(query: str, limit: int = 10) -> List[str]:
def search_hf_models(query: str, limit: int = 10) -> list[str]:
"""Simple search for HuggingFace models based on query (kept for backward compatibility)"""
return search_hf_models_fuzzy(query, limit)
def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
def validate_model_and_suggest(model_name: str, llm_type: str) -> str | None:
"""Validate model name and provide suggestions if invalid"""
if llm_type == "ollama":
available_models = check_ollama_models()
if available_models and model_name not in available_models:
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
# Check if the model exists remotely and get available tags
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
if model_exists_remotely and model_name in available_tags:
# Exact model exists remotely - suggest pulling it
error_msg += f"\n\nTo install the requested model:\n"
error_msg += "\n\nTo install the requested model:\n"
error_msg += f" ollama pull {model_name}\n"
# Show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
elif model_exists_remotely and available_tags:
# Base model exists but requested tag doesn't - suggest correct tags
base_model = model_name.split(':')[0]
requested_tag = model_name.split(':', 1)[1] if ':' in model_name else None
error_msg += f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
base_model = model_name.split(":")[0]
requested_tag = model_name.split(":", 1)[1] if ":" in model_name else None
error_msg += (
f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
)
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
for i, tag in enumerate(available_tags[:8], 1):
error_msg += f" {i}. ollama pull {tag}\n"
if len(available_tags) > 8:
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
# Also show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
# Model doesn't exist remotely - show fuzzy suggestions
suggestions = search_ollama_models_fuzzy(model_name, available_models)
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
if suggestions:
error_msg += "\n\nDid you mean one of these installed models?\n"
for i, suggestion in enumerate(suggestions, 1):
@@ -357,23 +363,25 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
error_msg += f" {i}. {model}\n"
if len(available_models) > 8:
error_msg += f" ... and {len(available_models) - 8} more\n"
error_msg += "\n\nCommands:"
error_msg += "\n ollama list # List installed models"
if model_exists_remotely and available_tags:
if model_name in available_tags:
error_msg += f"\n ollama pull {model_name} # Install requested model"
else:
error_msg += f"\n ollama pull {available_tags[0]} # Install recommended variant"
error_msg += (
f"\n ollama pull {available_tags[0]} # Install recommended variant"
)
error_msg += "\n https://ollama.com/library # Browse available models"
return error_msg
elif llm_type == "hf":
# For HF models, we can do a quick existence check
if not check_hf_model_exists(model_name):
# Use HF Hub's native fuzzy search directly
search_suggestions = search_hf_models_fuzzy(model_name, limit=8)
error_msg = f"Model '{model_name}' not found on HuggingFace Hub."
if search_suggestions:
error_msg += "\n\nDid you mean one of these?\n"
@@ -385,10 +393,10 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
error_msg += "\n\nPopular chat models:\n"
for i, model in enumerate(popular_models[:5], 1):
error_msg += f" {i}. {model}\n"
error_msg += f"\nSearch more: https://huggingface.co/models?search={model_name}&pipeline_tag=text-generation"
return error_msg
return None # Model is valid or we can't check
@@ -451,28 +459,27 @@ class OllamaChat(LLMInterface):
# Check if the Ollama server is responsive
if host:
requests.get(host)
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama")
if model_error:
raise ValueError(model_error)
except ImportError:
raise ImportError(
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
)
except requests.exceptions.ConnectionError:
logger.error(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
)
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
)
def ask(self, prompt: str, **kwargs) -> str:
import requests
import json
import requests
full_url = f"{self.host}/api/generate"
payload = {
"model": self.model,
@@ -482,7 +489,7 @@ class OllamaChat(LLMInterface):
}
logger.debug(f"Sending request to Ollama: {payload}")
try:
logger.info(f"Sending request to Ollama and waiting for response...")
logger.info("Sending request to Ollama and waiting for response...")
response = requests.post(full_url, data=json.dumps(payload))
response.raise_for_status()
@@ -506,15 +513,15 @@ class HFChat(LLMInterface):
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'")
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model_name, "hf")
if model_error:
raise ValueError(model_error)
try:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
raise ImportError(
"The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
@@ -537,36 +544,34 @@ class HFChat(LLMInterface):
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True
trust_remote_code=True,
)
# Move model to device if not using device_map
if self.device != "cpu" and "device_map" not in str(self.model):
self.model = self.model.to(self.device)
# Set pad token if not present
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def ask(self, prompt: str, **kwargs) -> str:
print('kwargs in HF: ', kwargs)
print("kwargs in HF: ", kwargs)
# Check if this is a Qwen model and add /no_think by default
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
# For Qwen models, automatically add /no_think to the prompt
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
prompt = prompt + " /no_think"
# Prepare chat template
messages = [{"role": "user", "content": prompt}]
# Apply chat template if available
if hasattr(self.tokenizer, "apply_chat_template"):
try:
formatted_prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
messages, tokenize=False, add_generation_prompt=True
)
except Exception as e:
logger.warning(f"Chat template failed, using raw prompt: {e}")
@@ -577,13 +582,9 @@ class HFChat(LLMInterface):
# Tokenize input
inputs = self.tokenizer(
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048
formatted_prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048
)
# Move inputs to device
if self.device != "cpu":
inputs = {k: v.to(self.device) for k, v in inputs.items()}
@@ -597,32 +598,29 @@ class HFChat(LLMInterface):
"pad_token_id": self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
}
# Handle temperature=0 for greedy decoding
if generation_config["temperature"] == 0.0:
generation_config["do_sample"] = False
generation_config.pop("temperature")
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
# Generate
with torch.no_grad():
outputs = self.model.generate(
**inputs,
**generation_config
)
outputs = self.model.generate(**inputs, **generation_config)
# Decode response
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return response.strip()
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
@@ -649,11 +647,7 @@ class OpenAIChat(LLMInterface):
"messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.get("max_tokens", 1000),
"temperature": kwargs.get("temperature", 0.7),
**{
k: v
for k, v in kwargs.items()
if k not in ["max_tokens", "temperature"]
},
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
}
logger.info(f"Sending request to OpenAI with model {self.model}")
@@ -675,7 +669,7 @@ class SimulatedChat(LLMInterface):
return "This is a simulated answer from the LLM based on the retrieved context."
def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
def get_llm(llm_config: dict[str, Any] | None = None) -> LLMInterface:
"""
Factory function to get an LLM interface based on configuration.

View File

@@ -5,12 +5,14 @@ from pathlib import Path
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from .api import LeannBuilder, LeannSearcher, LeannChat
from .api import LeannBuilder, LeannChat, LeannSearcher
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
"""Extract text from PDF using PyMuPDF for better quality."""
try:
import fitz # PyMuPDF
doc = fitz.open(file_path)
text = ""
for page in doc:
@@ -21,10 +23,12 @@ def extract_pdf_text_with_pymupdf(file_path: str) -> str:
# Fallback to default reader
return None
def extract_pdf_text_with_pdfplumber(file_path: str) -> str:
"""Extract text from PDF using pdfplumber for better quality."""
try:
import pdfplumber
text = ""
with pdfplumber.open(file_path) as pdf:
for page in pdf.pages:
@@ -72,18 +76,12 @@ Examples:
# Build command
build_parser = subparsers.add_parser("build", help="Build document index")
build_parser.add_argument("index_name", help="Index name")
build_parser.add_argument(
"--docs", type=str, required=True, help="Documents directory"
)
build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
build_parser.add_argument(
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
)
build_parser.add_argument(
"--embedding-model", type=str, default="facebook/contriever"
)
build_parser.add_argument(
"--force", "-f", action="store_true", help="Force rebuild"
)
build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
build_parser.add_argument("--graph-degree", type=int, default=32)
build_parser.add_argument("--complexity", type=int, default=64)
build_parser.add_argument("--num-threads", type=int, default=1)
@@ -129,7 +127,7 @@ Examples:
)
# List command
list_parser = subparsers.add_parser("list", help="List all indexes")
subparsers.add_parser("list", help="List all indexes")
return parser
@@ -137,17 +135,13 @@ Examples:
print("Stored LEANN indexes:")
if not self.indexes_dir.exists():
print(
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
)
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
return
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
if not index_dirs:
print(
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
)
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
return
print(f"Found {len(index_dirs)} indexes:")
@@ -157,15 +151,15 @@ Examples:
print(f" {i}. {index_name} [{status}]")
if self.index_exists(index_name):
meta_file = index_dir / "documents.leann.meta.json"
size_mb = sum(
f.stat().st_size for f in index_dir.iterdir() if f.is_file()
) / (1024 * 1024)
index_dir / "documents.leann.meta.json"
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (
1024 * 1024
)
print(f" Size: {size_mb:.1f} MB")
if index_dirs:
example_name = index_dirs[0].name
print(f"\nUsage:")
print("\nUsage:")
print(f' leann search {example_name} "your query"')
print(f" leann ask {example_name} --interactive")
@@ -175,19 +169,20 @@ Examples:
# Try to use better PDF parsers first
documents = []
docs_path = Path(docs_dir)
for file_path in docs_path.rglob("*.pdf"):
print(f"Processing PDF: {file_path}")
# Try PyMuPDF first (best quality)
text = extract_pdf_text_with_pymupdf(str(file_path))
if text is None:
# Try pdfplumber
text = extract_pdf_text_with_pdfplumber(str(file_path))
if text:
# Create a simple document structure
from llama_index.core import Document
doc = Document(text=text, metadata={"source": str(file_path)})
documents.append(doc)
else:

View File

@@ -4,11 +4,12 @@ Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""
import numpy as np
import torch
from typing import List, Dict, Any
import logging
import os
from typing import Any
import numpy as np
import torch
# Set up logger with proper level
logger = logging.getLogger(__name__)
@@ -17,11 +18,11 @@ log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Global model cache to avoid repeated loading
_model_cache: Dict[str, Any] = {}
_model_cache: dict[str, Any] = {}
def compute_embeddings(
texts: List[str],
texts: list[str],
model_name: str,
mode: str = "sentence-transformers",
is_build: bool = False,
@@ -59,7 +60,7 @@ def compute_embeddings(
def compute_embeddings_sentence_transformers(
texts: List[str],
texts: list[str],
model_name: str,
use_fp16: bool = True,
device: str = "auto",
@@ -114,9 +115,7 @@ def compute_embeddings_sentence_transformers(
logger.info(f"Using cached optimized model: {model_name}")
model = _model_cache[cache_key]
else:
logger.info(
f"Loading and caching optimized SentenceTransformer model: {model_name}"
)
logger.info(f"Loading and caching optimized SentenceTransformer model: {model_name}")
from sentence_transformers import SentenceTransformer
logger.info(f"Using device: {device}")
@@ -134,9 +133,7 @@ def compute_embeddings_sentence_transformers(
if hasattr(torch.mps, "set_per_process_memory_fraction"):
torch.mps.set_per_process_memory_fraction(0.9)
except AttributeError:
logger.warning(
"Some MPS optimizations not available in this PyTorch version"
)
logger.warning("Some MPS optimizations not available in this PyTorch version")
elif device == "cpu":
# TODO: Haven't tested this yet
torch.set_num_threads(min(8, os.cpu_count() or 4))
@@ -226,25 +223,22 @@ def compute_embeddings_sentence_transformers(
device=device,
)
logger.info(
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
# Validate results
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
raise RuntimeError(
f"Detected NaN or Inf values in embeddings, model: {model_name}"
)
raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
return embeddings
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
try:
import openai
import os
import openai
except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}")
@@ -294,16 +288,12 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
logger.info(
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
print(f"len of embeddings: {len(embeddings)}")
return embeddings
def compute_embeddings_mlx(
chunks: List[str], model_name: str, batch_size: int = 16
) -> np.ndarray:
def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Computes embeddings using an MLX model."""
try:

View File

@@ -1,12 +1,12 @@
import time
import atexit
import logging
import os
import socket
import subprocess
import sys
import os
import logging
import time
from pathlib import Path
from typing import Optional
import psutil
# Set up logging based on environment variable
@@ -33,7 +33,7 @@ def _get_available_port(start_port: int = 5557) -> int:
return port
except OSError:
port += 1
raise RuntimeError(f"No available ports found in range {start_port}-{start_port+100}")
raise RuntimeError(f"No available ports found in range {start_port}-{start_port + 100}")
def _check_port(port: int) -> bool:
@@ -182,8 +182,8 @@ class EmbeddingServerManager:
e.g., "leann_backend_diskann.embedding_server"
"""
self.backend_module_name = backend_module_name
self.server_process: Optional[subprocess.Popen] = None
self.server_port: Optional[int] = None
self.server_process: subprocess.Popen | None = None
self.server_port: int | None = None
self._atexit_registered = False
def start_server(
@@ -234,10 +234,10 @@ class EmbeddingServerManager:
return False, port
logger.info(f"Starting server on port {actual_port} for Colab environment")
# Use a simpler startup strategy for Colab
command = self._build_server_command(actual_port, model_name, embedding_mode, **kwargs)
try:
# In Colab, we'll use a more direct approach
self._launch_server_process_colab(command, actual_port)
@@ -246,26 +246,16 @@ class EmbeddingServerManager:
logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port
def _has_compatible_running_server(
self, model_name: str, passages_file: str
) -> bool:
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
"""Check if we have a compatible running server."""
if not (
self.server_process
and self.server_process.poll() is None
and self.server_port
):
if not (self.server_process and self.server_process.poll() is None and self.server_port):
return False
if _check_process_matches_config(self.server_port, model_name, passages_file):
logger.info(
f"Existing server process (PID {self.server_process.pid}) is compatible"
)
logger.info(f"Existing server process (PID {self.server_process.pid}) is compatible")
return True
logger.info(
"Existing server process is incompatible. Should start a new server."
)
logger.info("Existing server process is incompatible. Should start a new server.")
return False
def _start_new_server(
@@ -400,7 +390,7 @@ class EmbeddingServerManager:
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""
max_wait, wait_interval = 30, 0.5 # Shorter timeout for Colab
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
logger.info("Colab embedding server is ready!")
@@ -409,7 +399,7 @@ class EmbeddingServerManager:
if self.server_process and self.server_process.poll() is not None:
# Check for error output
stdout, stderr = self.server_process.communicate()
logger.error(f"Colab server terminated during startup.")
logger.error("Colab server terminated during startup.")
logger.error(f"stdout: {stdout}")
logger.error(f"stderr: {stderr}")
return False, port

View File

@@ -1,15 +1,14 @@
from abc import ABC, abstractmethod
from typing import Any, Literal
import numpy as np
from typing import Dict, Any, List, Literal, Optional
class LeannBackendBuilderInterface(ABC):
"""Backend interface for building indexes"""
@abstractmethod
def build(
self, data: np.ndarray, ids: List[str], index_path: str, **kwargs
) -> None:
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None:
"""Build index
Args:
@@ -35,9 +34,7 @@ class LeannBackendSearcherInterface(ABC):
pass
@abstractmethod
def _ensure_server_running(
self, passages_source_file: str, port: Optional[int], **kwargs
) -> int:
def _ensure_server_running(self, passages_source_file: str, port: int | None, **kwargs) -> int:
"""Ensure server is running"""
pass
@@ -51,9 +48,9 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
zmq_port: int | None = None,
**kwargs,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Search for nearest neighbors
Args:
@@ -77,7 +74,7 @@ class LeannBackendSearcherInterface(ABC):
self,
query: str,
use_server_if_available: bool = True,
zmq_port: Optional[int] = None,
zmq_port: int | None = None,
) -> np.ndarray:
"""Compute embedding for a query string

View File

@@ -1,13 +1,13 @@
# packages/leann-core/src/leann/registry.py
from typing import Dict, TYPE_CHECKING
import importlib
import importlib.metadata
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from leann.interface import LeannBackendFactoryInterface
BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}
BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
def register_backend(name: str):
@@ -31,13 +31,11 @@ def autodiscover_backends():
backend_module_name = dist_name.replace("-", "_")
discovered_backends.append(backend_module_name)
for backend_module_name in sorted(
discovered_backends
): # sort for deterministic loading
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
try:
importlib.import_module(backend_module_name)
# Registration message is printed by the decorator
except ImportError as e:
except ImportError:
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
pass
# print("INFO: Backend auto-discovery finished.")

View File

@@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Any, Literal, Optional
from typing import Any, Literal
import numpy as np
@@ -38,9 +38,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model:
print(
"WARNING: embedding_model not found in meta.json. Recompute will fail."
)
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
@@ -48,26 +46,22 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
backend_module_name=backend_module_name,
)
def _load_meta(self) -> Dict[str, Any]:
def _load_meta(self) -> dict[str, Any]:
"""Loads the metadata file associated with the index."""
# This is the corrected logic for finding the meta file.
meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
with open(meta_path, "r", encoding="utf-8") as f:
with open(meta_path, encoding="utf-8") as f:
return json.load(f)
def _ensure_server_running(
self, passages_source_file: str, port: int, **kwargs
) -> int:
def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> int:
"""
Ensures the embedding server is running if recompute is needed.
This is a helper for subclasses.
"""
if not self.embedding_model:
raise ValueError(
"Cannot use recompute mode without 'embedding_model' in meta.json."
)
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
server_started, actual_port = self.embedding_server_manager.start_server(
port=port,
@@ -78,9 +72,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
enable_warmup=kwargs.get("enable_warmup", False),
)
if not server_started:
raise RuntimeError(
f"Failed to start embedding server on port {actual_port}"
)
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
return actual_port
@@ -109,9 +101,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
# on that port?
# Ensure we have a server with passages_file for compatibility
passages_source_file = (
self.index_dir / f"{self.index_path.name}.meta.json"
)
passages_source_file = self.index_dir / f"{self.index_path.name}.meta.json"
# Convert to absolute path to ensure server can find it
zmq_port = self._ensure_server_running(
str(passages_source_file.resolve()), zmq_port
@@ -132,8 +122,8 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
"""Compute embeddings using the ZMQ embedding server."""
import zmq
import msgpack
import zmq
try:
context = zmq.Context()
@@ -172,9 +162,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
zmq_port: int | None = None,
**kwargs,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Search for the top_k nearest neighbors of the query vector.