- Fix ambiguous fullwidth characters (commas, parentheses) in strings and comments
- Replace Chinese comments with English equivalents
- Fix unused imports with proper noqa annotations for intentional imports
- Fix bare except clauses with specific exception types
- Fix redefined variables and undefined names
- Add ruff noqa annotations for generated protobuf files
- Add lint and format check to GitHub Actions CI pipeline
156 lines
5.5 KiB
Python
156 lines
5.5 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Document search demo with recompute mode
|
||
"""
|
||
|
||
import shutil
|
||
import time
|
||
from pathlib import Path
|
||
|
||
# The backend packages are imported purely for their import-time side effect
# (plugin registration); nothing in this file refers to them by name, hence
# the F401 suppressions.
try:
    import leann_backend_diskann  # noqa: F401
    import leann_backend_hnsw  # noqa: F401
except ImportError as import_error:
    # A missing backend is reported but does not abort the script here.
    print(f"WARNING: Could not import backend packages. Error: {import_error}")
else:
    print("INFO: Backend packages imported successfully.")
|
||
|
||
# Import upper-level API from leann-core
|
||
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||
|
||
|
||
def load_sample_documents():
    """Return the hard-coded demo corpus.

    Returns:
        A new list of three dicts, each with a ``"title"`` and a
        ``"content"`` string, in a fixed order.
    """
    title_and_content = (
        (
            "Intro to Python",
            "Python is a high-level, interpreted language known for simplicity.",
        ),
        (
            "ML Basics",
            "Machine learning builds systems that learn from data.",
        ),
        (
            "Data Structures",
            "Data structures like arrays, lists, and graphs organize data.",
        ),
    )
    return [{"title": title, "content": content} for title, content in title_and_content]
|
||
|
||
|
||
def _timed_search(searcher, query, **search_kwargs):
    """Run ``searcher.search`` and return ``(results, elapsed_seconds)``.

    Uses ``time.perf_counter`` so the measured duration is not affected by
    wall-clock adjustments.
    """
    started = time.perf_counter()
    results = searcher.search(query, **search_kwargs)
    return results, time.perf_counter() - started


def _print_results(results):
    """Print one numbered line (ID, score, text, metadata) per search result."""
    for rank, res in enumerate(results, 1):
        print(
            f" {rank}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
        )


def main():
    """Run the four-phase document search demo.

    Phases:
        1. Delete any existing ``./test_indices`` directory, then build an
           index over the sample documents with the ``diskann`` backend.
        2. Run a basic search for a fixed query and print timing and results.
        3. Run the same search with the recompute parameters, print timing
           and results, and compare scores position by position. Any
           exception raised in this phase is reported and the demo continues.
        4. Ask the same query through ``LeannChat`` and print the response.
    """
    print("==========================================================")
    print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
    print("==========================================================")

    index_dir = Path("./test_indices")
    index_path = str(index_dir / "documents.diskann")
    backend_name = "diskann"

    # Start from a clean slate so a stale index cannot influence the demo.
    if index_dir.exists():
        print(f"--- Cleaning up old index directory: {index_dir} ---")
        shutil.rmtree(index_dir)

    # --- 1. Build index ---
    print(f"\n[PHASE 1] Building index using '{backend_name}' backend...")

    builder = LeannBuilder(backend_name=backend_name, graph_degree=32, complexity=64)

    documents = load_sample_documents()
    print(f"Loaded {len(documents)} sample documents.")
    for doc in documents:
        builder.add_text(doc["content"], metadata={"title": doc["title"]})

    builder.build_index(index_path)
    print("\nIndex built!")

    # --- 2. Basic search demo ---
    print(f"\n[PHASE 2] Basic search using '{backend_name}' backend...")
    searcher = LeannSearcher(index_path=index_path)

    query = "What is machine learning?"
    print(f"\nQuery: '{query}'")

    print("\n--- Basic search mode (PQ computation) ---")
    results, basic_time = _timed_search(searcher, query, top_k=2)

    print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
    print(">>> Basic search results <<<")
    _print_results(results)

    # --- 3. Recompute search demo ---
    print("\n[PHASE 3] Recompute search using embedding server...")

    print("\n--- Recompute search mode (get real embeddings via network) ---")

    # Configure recompute parameters
    recompute_params = {
        # NOTE(review): the key is spelled "beighbor"; presumably it must match
        # the kwarg name the backend reads -- confirm before renaming it.
        "recompute_beighbor_embeddings": True,  # Enable network recomputation
        "USE_DEFERRED_FETCH": False,  # Don't use deferred fetch
        "skip_search_reorder": True,  # Skip search reordering
        "dedup_node_dis": True,  # Enable node distance deduplication
        "prune_ratio": 0.1,  # Pruning ratio 10%
        "batch_recompute": False,  # Don't use batch recomputation
        "global_pruning": False,  # Don't use global pruning
        "zmq_port": 5555,  # ZMQ port
        "embedding_model": "sentence-transformers/all-mpnet-base-v2",
    }

    print("Recompute parameter configuration:")
    for key, value in recompute_params.items():
        print(f" {key}: {value}")

    print("\n🔄 Executing Recompute search...")
    # The whole phase is best-effort: a failure here is reported and the demo
    # still proceeds to the chat phase.
    try:
        recompute_results, recompute_time = _timed_search(
            searcher, query, top_k=2, **recompute_params
        )

        print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
        print(">>> Recompute search results <<<")
        _print_results(recompute_results)

        # Compare results
        print("\n--- Result comparison ---")
        print(f"Basic search time: {basic_time:.3f} seconds")
        print(f"Recompute time: {recompute_time:.3f} seconds")

        print("\nBasic search vs Recompute results:")
        # zip stops at the shorter result list, so only positions present in
        # both searches are compared.
        for position, (basic, recomputed) in enumerate(zip(results, recompute_results), 1):
            basic_score = basic.score
            recompute_score = recomputed.score
            score_diff = abs(basic_score - recompute_score)
            print(
                f" Position {position}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}"
            )

        if recompute_time > basic_time:
            print("✅ Recompute mode working correctly (more accurate but slower)")
        else:
            print("i️ Recompute time is unusually fast, network recomputation may not be enabled")

    except Exception as e:
        print(f"❌ Recompute search failed: {e}")
        print("This usually indicates an embedding server connection issue")

    # --- 4. Chat demo ---
    print("\n[PHASE 4] Starting chat session...")
    chat = LeannChat(index_path=index_path)
    chat_response = chat.ask(query)
    print(f"You: {query}")
    print(f"Leann: {chat_response}")

    print("\n==========================================================")
    print("✅ Demo finished successfully!")
    print("==========================================================")
||
|
||
|
||
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|