refactor: check if current emb_server has correct passages/embedder

Andy Lee
2025-07-13 22:33:33 -07:00
parent 77ac013a74
commit 3b5a185e60
5 changed files with 915 additions and 229 deletions
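The refactor lets the embedding-server manager ask an already-running server which passages file and embedding model it is serving before deciding to reuse it, using new MessagePack control messages over ZMQ. A minimal client-side sketch of that check, assuming a placeholder port and caller-supplied expected values (the actual reuse logic lives in the server-manager file at the end of this diff):

import zmq
import msgpack

def server_matches(port: int, expected_meta_path: str, expected_model: str) -> bool:
    # Hypothetical helper, not part of the commit: asks a running embedding
    # server for its current meta path and model and compares them.
    context = zmq.Context()
    sock = context.socket(zmq.REQ)
    sock.setsockopt(zmq.RCVTIMEO, 2000)  # fail fast if the server does not answer
    sock.setsockopt(zmq.SNDTIMEO, 2000)
    sock.connect(f"tcp://localhost:{port}")
    try:
        sock.send(msgpack.packb(["__QUERY_META_PATH__"]))
        meta_path = msgpack.unpackb(sock.recv())[0]
        sock.send(msgpack.packb(["__QUERY_MODEL__"]))
        model = msgpack.unpackb(sock.recv())[0]
        return meta_path == expected_meta_path and model == expected_model
    except zmq.ZMQError:
        return False
    finally:
        sock.close()
        context.term()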

View File

@@ -14,43 +14,51 @@ dotenv.load_dotenv()
# Default WeChat export directory
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"


def create_leann_index_from_multiple_wechat_exports(
    export_dirs: List[Path],
    index_path: str = "wechat_history_index.leann",
    max_count: int = -1,
):
    """
    Create LEANN index from multiple WeChat export data sources.

    Args:
        export_dirs: List of Path objects pointing to WeChat export directories
        index_path: Path to save the LEANN index
        max_count: Maximum number of chat entries to process per export
    """
    print("Creating LEANN index from multiple WeChat export data sources...")

    # Load documents using WeChatHistoryReader from history_data
    from history_data.wechat_history import WeChatHistoryReader

    reader = WeChatHistoryReader()

    INDEX_DIR = Path(index_path).parent
    if not INDEX_DIR.exists():
        print(f"--- Index directory not found, building new index ---")
        all_documents = []
        total_processed = 0

        # Process each WeChat export directory
        for i, export_dir in enumerate(export_dirs):
            print(
                f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
            )
            try:
                documents = reader.load_data(
                    wechat_export_dir=str(export_dir),
                    max_count=max_count,
                    concatenate_messages=False,  # Disable concatenation - one message per document
                )
                if documents:
                    print(f"Loaded {len(documents)} chat documents from {export_dir}")
                    all_documents.extend(documents)
                    total_processed += len(documents)

                    # Check if we've reached the max count
                    if max_count > 0 and total_processed >= max_count:
                        print(f"Reached max count of {max_count} documents")
@@ -60,16 +68,18 @@ def create_leann_index_from_multiple_wechat_exports(export_dirs: List[Path], ind
            except Exception as e:
                print(f"Error processing {export_dir}: {e}")
                continue

        if not all_documents:
            print("No documents loaded from any source. Exiting.")
            return None

        print(
            f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
        )

        # Create text splitter with 256 chunk size
        text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)

        # Convert Documents to text strings and chunk them
        all_texts = []
        for doc in all_documents:
@@ -77,43 +87,50 @@ def create_leann_index_from_multiple_wechat_exports(export_dirs: List[Path], ind
            nodes = text_splitter.get_nodes_from_documents([doc])
            for node in nodes:
                all_texts.append(node.get_content())

        print(
            f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
        )

        # Create LEANN index directory
        print(f"--- Index directory not found, building new index ---")
        INDEX_DIR.mkdir(exist_ok=True)
        print(f"--- Building new LEANN index ---")
        print(f"\n[PHASE 1] Building Leann index...")

        # Use HNSW backend for better macOS compatibility
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="Qwen/Qwen3-Embedding-0.6B",
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1,  # Force single-threaded mode
        )

        print(f"Adding {len(all_texts)} chat chunks to index...")
        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"\nLEANN index built at {index_path}!")
    else:
        print(f"--- Using existing index at {INDEX_DIR} ---")

    return index_path


def create_leann_index(
    export_dir: str = None,
    index_path: str = "wechat_history_index.leann",
    max_count: int = 1000,
):
    """
    Create LEANN index from WeChat chat history data.

    Args:
        export_dir: Path to the WeChat export directory (optional, uses default if None)
        index_path: Path to save the LEANN index
@@ -121,34 +138,35 @@ def create_leann_index(export_dir: str = None, index_path: str = "wechat_history
""" """
print("Creating LEANN index from WeChat chat history data...") print("Creating LEANN index from WeChat chat history data...")
INDEX_DIR = Path(index_path).parent INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists(): if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---") print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True) INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---") print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...") print(f"\n[PHASE 1] Building Leann index...")
# Load documents using WeChatHistoryReader from history_data # Load documents using WeChatHistoryReader from history_data
from history_data.wechat_history import WeChatHistoryReader from history_data.wechat_history import WeChatHistoryReader
reader = WeChatHistoryReader() reader = WeChatHistoryReader()
documents = reader.load_data( documents = reader.load_data(
wechat_export_dir=export_dir, wechat_export_dir=export_dir,
max_count=max_count, max_count=max_count,
concatenate_messages=False # Disable concatenation - one message per document concatenate_messages=False, # Disable concatenation - one message per document
) )
if not documents: if not documents:
print("No documents loaded. Exiting.") print("No documents loaded. Exiting.")
return None return None
print(f"Loaded {len(documents)} chat documents") print(f"Loaded {len(documents)} chat documents")
# Create text splitter with 256 chunk size # Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25) text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them # Convert Documents to text strings and chunk them
all_texts = [] all_texts = []
for doc in documents: for doc in documents:
@@ -156,54 +174,55 @@ def create_leann_index(export_dir: str = None, index_path: str = "wechat_history
            nodes = text_splitter.get_nodes_from_documents([doc])
            for node in nodes:
                all_texts.append(node.get_content())

        print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")

        # Create LEANN index directory
        print(f"--- Index directory not found, building new index ---")
        INDEX_DIR.mkdir(exist_ok=True)
        print(f"--- Building new LEANN index ---")
        print(f"\n[PHASE 1] Building Leann index...")

        # Use HNSW backend for better macOS compatibility
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",  # MLX-optimized model
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1,  # Force single-threaded mode
        )

        print(f"Adding {len(all_texts)} chat chunks to index...")
        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"\nLEANN index built at {index_path}!")
    else:
        print(f"--- Using existing index at {INDEX_DIR} ---")

    return index_path


async def query_leann_index(index_path: str, query: str):
    """
    Query the LEANN index.

    Args:
        index_path: Path to the LEANN index
        query: The query string
    """
    print(f"\n[PHASE 2] Starting Leann chat session...")

    chat = LeannChat(index_path=index_path)
    print(f"You: {query}")
    chat_response = chat.ask(
        query,
        top_k=5,
        recompute_beighbor_embeddings=True,
        complexity=32,
        beam_width=1,
@@ -212,52 +231,74 @@ async def query_leann_index(index_path: str, query: str):
"model": "gpt-4o", "model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"), "api_key": os.getenv("OPENAI_API_KEY"),
}, },
llm_kwargs={ llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
"temperature": 0.0,
"max_tokens": 1000
}
) )
print(f"Leann: {chat_response}") print(f"Leann: {chat_response}")
async def main(): async def main():
"""Main function with integrated WeChat export functionality.""" """Main function with integrated WeChat export functionality."""
# Parse command line arguments # Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN WeChat History Reader - Create and query WeChat chat history index') parser = argparse.ArgumentParser(
parser.add_argument('--export-dir', type=str, default=DEFAULT_WECHAT_EXPORT_DIR, description="LEANN WeChat History Reader - Create and query WeChat chat history index"
help=f'Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})') )
parser.add_argument('--index-dir', type=str, default="./wechat_history_index_leann_test", parser.add_argument(
help='Directory to store the LEANN index (default: ./wechat_history_index_leann_test)') "--export-dir",
parser.add_argument('--max-entries', type=int, default=5000, type=str,
help='Maximum number of chat entries to process (default: 5000)') default=DEFAULT_WECHAT_EXPORT_DIR,
parser.add_argument('--query', type=str, default=None, help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
help='Single query to run (default: runs example queries)') )
parser.add_argument('--force-export', action='store_true', default=False, parser.add_argument(
help='Force re-export of WeChat data even if exports exist') "--index-dir",
type=str,
default="./wechat_history_index_leann_test",
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
)
parser.add_argument(
"--max-entries",
type=int,
default=5000,
help="Maximum number of chat entries to process (default: 5000)",
)
parser.add_argument(
"--query",
type=str,
default=None,
help="Single query to run (default: runs example queries)",
)
parser.add_argument(
"--force-export",
action="store_true",
default=False,
help="Force re-export of WeChat data even if exports exist",
)
args = parser.parse_args() args = parser.parse_args()
INDEX_DIR = Path(args.index_dir) INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann") INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
print(f"Using WeChat export directory: {args.export_dir}") print(f"Using WeChat export directory: {args.export_dir}")
print(f"Index directory: {INDEX_DIR}") print(f"Index directory: {INDEX_DIR}")
print(f"Max entries: {args.max_entries}") print(f"Max entries: {args.max_entries}")
# Initialize WeChat reader with export capabilities # Initialize WeChat reader with export capabilities
from history_data.wechat_history import WeChatHistoryReader from history_data.wechat_history import WeChatHistoryReader
reader = WeChatHistoryReader() reader = WeChatHistoryReader()
# Find existing exports or create new ones using the centralized method # Find existing exports or create new ones using the centralized method
export_dirs = reader.find_or_export_wechat_data(args.export_dir) export_dirs = reader.find_or_export_wechat_data(args.export_dir)
if not export_dirs: if not export_dirs:
print("Failed to find or export WeChat data. Exiting.") print("Failed to find or export WeChat data. Exiting.")
return return
# Create or load the LEANN index from all sources # Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_wechat_exports(export_dirs, INDEX_PATH, max_count=args.max_entries) index_path = create_leann_index_from_multiple_wechat_exports(
export_dirs, INDEX_PATH, max_count=args.max_entries
)
if index_path: if index_path:
if args.query: if args.query:
# Run single query # Run single query
@@ -267,10 +308,11 @@ async def main():
            queries = [
                "我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
            ]

        for query in queries:
            print("\n" + "=" * 60)
            await query_leann_index(index_path, query)


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -14,6 +14,7 @@ import os
from contextlib import contextmanager
import zmq
import numpy as np
import msgpack
from pathlib import Path

RED = "\033[91m"
@@ -26,6 +27,7 @@ class SimplePassageLoader:
""" """
def __init__(self, passages_data: Optional[Dict[str, Any]] = None): def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
self.passages_data = passages_data or {} self.passages_data = passages_data or {}
self._meta_path = ''
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]: def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID""" """Get passage by ID"""
@@ -38,6 +40,9 @@ class SimplePassageLoader:
    def __len__(self) -> int:
        return len(self.passages_data)

    def keys(self):
        return self.passages_data.keys()


def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
    """
@@ -101,8 +106,13 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
        def __len__(self) -> int:
            return len(self.label_map)

        def keys(self):
            return self.label_map.keys()

    loader = LazyPassageLoader(passage_manager, label_map)
    loader._meta_path = meta_file
    return loader


def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
    """
@@ -353,6 +363,100 @@ def create_embedding_server_thread(
                continue

            print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")

            # Handle control messages (MessagePack format)
            try:
                request_payload = msgpack.unpackb(message)
                if isinstance(request_payload, list) and len(request_payload) >= 1:
                    if request_payload[0] == "__QUERY_META_PATH__":
                        # Return the current meta path being used by the server
                        current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else ''
                        response = [current_meta_path]
                        socket.send_multipart([identity, b'', msgpack.packb(response)])
                        continue
                    elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2:
                        # Update the server's meta path and reload passages
                        new_meta_path = request_payload[1]
                        try:
                            print(f"INFO: Updating server meta path to: {new_meta_path}")
                            # Reload passages from the new meta file
                            passages = load_passages_from_metadata(new_meta_path)
                            # Store the meta path for future queries
                            passages._meta_path = new_meta_path
                            response = ["SUCCESS"]
                            print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages")
                        except Exception as e:
                            print(f"ERROR: Failed to update meta path: {e}")
                            response = ["FAILED", str(e)]
                        socket.send_multipart([identity, b'', msgpack.packb(response)])
                        continue
                    elif request_payload[0] == "__QUERY_MODEL__":
                        # Return the current model being used by the server
                        response = [model_name]
                        socket.send_multipart([identity, b'', msgpack.packb(response)])
                        continue
                    elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2:
                        # Update the server's embedding model
                        new_model_name = request_payload[1]
                        try:
                            print(f"INFO: Updating server model from {model_name} to: {new_model_name}")
                            # Clean up old model to free memory
                            if not use_mlx:
                                print("INFO: Releasing old model from memory...")
                                old_model = model
                                old_tokenizer = tokenizer

                                # Load new tokenizer first
                                print(f"Loading new tokenizer for {new_model_name}...")
                                tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True)

                                # Load new model
                                print(f"Loading new model {new_model_name}...")
                                model = AutoModel.from_pretrained(new_model_name).to(device).eval()

                                # Optimize new model
                                if cuda_available or mps_available:
                                    try:
                                        model = model.half()
                                        model = torch.compile(model)
                                        print(f"INFO: Using FP16 precision with model: {new_model_name}")
                                    except Exception as e:
                                        print(f"WARNING: Model optimization failed: {e}")

                                # Now safely delete old model after new one is loaded
                                del old_model
                                del old_tokenizer

                                # Clear GPU cache if available
                                if device.type == "cuda":
                                    torch.cuda.empty_cache()
                                    print("INFO: Cleared CUDA cache")
                                elif device.type == "mps":
                                    torch.mps.empty_cache()
                                    print("INFO: Cleared MPS cache")

                                # Force garbage collection
                                import gc
                                gc.collect()
                                print("INFO: Memory cleanup completed")

                            # Update model name
                            model_name = new_model_name
                            response = ["SUCCESS"]
                            print(f"INFO: Successfully updated model to: {new_model_name}")
                        except Exception as e:
                            print(f"ERROR: Failed to update model: {e}")
                            response = ["FAILED", str(e)]
                        socket.send_multipart([identity, b'', msgpack.packb(response)])
                        continue
            except:
                # Not a control message, continue with normal protobuf processing
                pass

            e2e_start = time.time()
            lookup_timer = DeviceTimer("text lookup")
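If a running server reports the wrong index, the same control channel can ask it to reload passages in place rather than restarting the process. A rough sketch of the client side of __UPDATE_META_PATH__, assuming a placeholder port; per the handler above, the server replies ["SUCCESS"] or ["FAILED", reason]:

import zmq
import msgpack

def update_server_meta_path(port: int, new_meta_path: str) -> bool:
    # Hypothetical helper mirroring the __UPDATE_META_PATH__ handler above;
    # not part of the commit itself.
    context = zmq.Context()
    sock = context.socket(zmq.REQ)
    sock.connect(f"tcp://localhost:{port}")
    try:
        sock.send(msgpack.packb(["__UPDATE_META_PATH__", new_meta_path]))
        reply = msgpack.unpackb(sock.recv())
        return bool(reply) and reply[0] == "SUCCESS"
    finally:
        sock.close()
        context.term()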

View File

@@ -17,10 +17,12 @@ import msgpack
import json
from pathlib import Path
from typing import Dict, Any, Optional, Union
import sys

RED = "\033[91m"
RESET = "\033[0m"


def is_similarity_metric():
    """
    Check if the metric type is similarity-based (like inner product).
@@ -28,22 +30,27 @@ def is_similarity_metric():
""" """
return True # 1 is METRIC_INNER_PRODUCT in FAISS return True # 1 is METRIC_INNER_PRODUCT in FAISS
# Function for E5-style average pooling # Function for E5-style average pooling
import torch import torch
from torch import Tensor from torch import Tensor
import torch.nn.functional as F import torch.nn.functional as F
def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor: def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
class SimplePassageLoader: class SimplePassageLoader:
""" """
Simple passage loader that replaces config.py dependencies Simple passage loader that replaces config.py dependencies
""" """
def __init__(self, passages_data: Optional[Dict[str, Any]] = None): def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
self.passages_data = passages_data or {} self.passages_data = passages_data or {}
self._meta_path = ""
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]: def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID""" """Get passage by ID"""
str_id = str(passage_id) str_id = str(passage_id)
@@ -52,54 +59,57 @@ class SimplePassageLoader:
        else:
            # Return empty text for missing passages
            return {"text": ""}

    def __len__(self) -> int:
        return len(self.passages_data)

    def keys(self):
        return self.passages_data.keys()


def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
    """
    Load passages using metadata file with PassageManager for lazy loading
    """
    # Load metadata to get passage sources
    with open(meta_file, "r") as f:
        meta = json.load(f)

    # Import PassageManager dynamically to avoid circular imports
    # Find the leann package directory relative to this file
    current_dir = Path(__file__).parent
    leann_core_path = current_dir.parent.parent / "leann-core" / "src"
    sys.path.insert(0, str(leann_core_path))
    try:
        from leann.api import PassageManager

        passage_manager = PassageManager(meta["passage_sources"])
    finally:
        sys.path.pop(0)

    # Load label map
    passages_dir = Path(meta_file).parent
    label_map_file = passages_dir / "leann.labels.map"
    if label_map_file.exists():
        import pickle

        with open(label_map_file, "rb") as f:
            label_map = pickle.load(f)
        print(f"Loaded label map with {len(label_map)} entries")
    else:
        raise FileNotFoundError(f"Label map file not found: {label_map_file}")

    print(f"Initialized lazy passage loading for {len(label_map)} passages")

    class LazyPassageLoader(SimplePassageLoader):
        def __init__(self, passage_manager, label_map):
            self.passage_manager = passage_manager
            self.label_map = label_map
            # Initialize parent with empty data
            super().__init__({})

        def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
            """Get passage by ID with lazy loading"""
            try:
@@ -118,12 +128,16 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
            except Exception as e:
                print(f"DEBUG: Exception getting passage {passage_id}: {e}")
                return {"text": ""}

        def __len__(self) -> int:
            return len(self.label_map)

        def keys(self):
            return self.label_map.keys()

    return LazyPassageLoader(passage_manager, label_map)


def create_hnsw_embedding_server(
    passages_file: Optional[str] = None,
    passages_data: Optional[Dict[str, str]] = None,
@@ -139,7 +153,7 @@ def create_hnsw_embedding_server(
):
    """
    Create and start a ZMQ-based embedding server for HNSW backend.

    Args:
        passages_file: Path to JSON file containing passage ID -> text mapping
        passages_data: Direct passage data dict (alternative to passages_file)
@@ -156,14 +170,14 @@ def create_hnsw_embedding_server(
print(f"Loading tokenizer for {model_name}...") print(f"Loading tokenizer for {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
print(f"Tokenizer loaded successfully!") print(f"Tokenizer loaded successfully!")
# Device setup # Device setup
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
cuda_available = torch.cuda.is_available() cuda_available = torch.cuda.is_available()
print(f"MPS available: {mps_available}") print(f"MPS available: {mps_available}")
print(f"CUDA available: {cuda_available}") print(f"CUDA available: {cuda_available}")
if cuda_available: if cuda_available:
device = torch.device("cuda") device = torch.device("cuda")
print("Using CUDA device") print("Using CUDA device")
@@ -173,7 +187,7 @@ def create_hnsw_embedding_server(
    else:
        device = torch.device("cpu")
        print("Using CPU device (no GPU acceleration available)")

    # Load model to the appropriate device
    print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
    print(f"Loading model {model_name}... (this may take a while if downloading)")
@@ -182,9 +196,10 @@ def create_hnsw_embedding_server(
    # Check port availability
    import socket

    def check_port(port):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(("localhost", port)) == 0

    if check_port(zmq_port):
        print(f"{RED}Port {zmq_port} is already in use{RESET}")
@@ -196,8 +211,14 @@ def create_hnsw_embedding_server(
        model = torch.compile(model)
        print(f"Using FP16 precision with model: {model_name}")
    elif use_int8:
        print(
            "- Using TorchAO for Int8 dynamic activation and Int8 weight quantization"
        )
        from torchao.quantization import (
            quantize_,
            Int8DynamicActivationInt8WeightConfig,
        )

        quantize_(model, Int8DynamicActivationInt8WeightConfig())
        model = torch.compile(model)
    model.eval()
@@ -209,8 +230,10 @@ def create_hnsw_embedding_server(
print(f"Using provided passages data: {len(passages)} passages") print(f"Using provided passages data: {len(passages)} passages")
elif passages_file: elif passages_file:
# Check if it's a metadata file or a single passages file # Check if it's a metadata file or a single passages file
if passages_file.endswith('.meta.json'): if passages_file.endswith(".meta.json"):
passages = load_passages_from_metadata(passages_file) passages = load_passages_from_metadata(passages_file)
# Store the meta path for future reference
passages._meta_path = passages_file
else: else:
# Try to find metadata file in same directory # Try to find metadata file in same directory
passages_dir = Path(passages_file).parent passages_dir = Path(passages_file).parent
@@ -220,8 +243,12 @@ def create_hnsw_embedding_server(
                passages = load_passages_from_metadata(str(meta_files[0]))
            else:
                # Fallback to original single file loading (will cause warnings)
                print(
                    "WARNING: No metadata file found, using single file loading (may cause missing passage warnings)"
                )
                passages = (
                    SimplePassageLoader()
                )  # Use empty loader to avoid massive warnings
    else:
        passages = SimplePassageLoader()
        print("No passages provided, using empty loader")
@@ -238,12 +265,13 @@ def create_hnsw_embedding_server(
    class DeviceTimer:
        """Device event-based timer for accurate timing."""

        def __init__(self, name="", device=device):
            self.name = name
            self.device = device
            self.start_time = 0
            self.end_time = 0
            if cuda_available:
                self.start_event = torch.cuda.Event(enable_timing=True)
                self.end_event = torch.cuda.Event(enable_timing=True)
@@ -289,30 +317,31 @@ def create_hnsw_embedding_server(
_is_e5_model = "e5" in model_name.lower() _is_e5_model = "e5" in model_name.lower()
_is_bge_model = "bge" in model_name.lower() _is_bge_model = "bge" in model_name.lower()
batch_size = len(texts_batch) batch_size = len(texts_batch)
# Validate no empty texts # Allow empty texts to pass through (remove validation)
for i, text in enumerate(texts_batch):
if not text or text.strip() == "":
raise RuntimeError(f"FATAL: Empty text at batch index {i}, ID: {ids_batch[i] if i < len(ids_batch) else 'unknown'}")
# E5 model preprocessing # E5 model preprocessing
if _is_e5_model: if _is_e5_model:
processed_texts_batch = [f"passage: {text}" for text in texts_batch] processed_texts_batch = [f"passage: {text}" for text in texts_batch]
else: else:
processed_texts_batch = texts_batch processed_texts_batch = texts_batch
# Set max length # Set max length
if _is_e5_model: if _is_e5_model:
current_max_length = custom_max_length_param if custom_max_length_param is not None else 512 current_max_length = (
custom_max_length_param if custom_max_length_param is not None else 512
)
else: else:
current_max_length = custom_max_length_param if custom_max_length_param is not None else 256 current_max_length = (
custom_max_length_param if custom_max_length_param is not None else 256
)
tokenize_timer = DeviceTimer("tokenization (batch)", device) tokenize_timer = DeviceTimer("tokenization (batch)", device)
to_device_timer = DeviceTimer("transfer to device (batch)", device) to_device_timer = DeviceTimer("transfer to device (batch)", device)
embed_timer = DeviceTimer("embedding (batch)", device) embed_timer = DeviceTimer("embedding (batch)", device)
pool_timer = DeviceTimer("pooling (batch)", device) pool_timer = DeviceTimer("pooling (batch)", device)
norm_timer = DeviceTimer("normalization (batch)", device) norm_timer = DeviceTimer("normalization (batch)", device)
with tokenize_timer.timing(): with tokenize_timer.timing():
encoded_batch = tokenizer( encoded_batch = tokenizer(
processed_texts_batch, processed_texts_batch,
@@ -322,48 +351,71 @@ def create_hnsw_embedding_server(
return_tensors="pt", return_tensors="pt",
return_token_type_ids=False, return_token_type_ids=False,
) )
seq_length = encoded_batch["input_ids"].size(1) seq_length = encoded_batch["input_ids"].size(1)
with to_device_timer.timing(): with to_device_timer.timing():
enc = {k: v.to(device) for k, v in encoded_batch.items()} enc = {k: v.to(device) for k, v in encoded_batch.items()}
with torch.no_grad(): with torch.no_grad():
with embed_timer.timing(): with embed_timer.timing():
out = model(enc["input_ids"], enc["attention_mask"]) out = model(enc["input_ids"], enc["attention_mask"])
with pool_timer.timing(): with pool_timer.timing():
if _is_bge_model: if _is_bge_model:
pooled_embeddings = out.last_hidden_state[:, 0] pooled_embeddings = out.last_hidden_state[:, 0]
elif not hasattr(out, 'last_hidden_state'): elif not hasattr(out, "last_hidden_state"):
if isinstance(out, torch.Tensor) and len(out.shape) == 2: if isinstance(out, torch.Tensor) and len(out.shape) == 2:
pooled_embeddings = out pooled_embeddings = out
else: else:
print(f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}") print(
hidden_dim = getattr(model.config, 'hidden_size', 384 if _is_e5_model else 768) f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
pooled_embeddings = torch.zeros((batch_size, hidden_dim), device=device, dtype=enc["input_ids"].dtype if hasattr(enc["input_ids"], "dtype") else torch.float32) )
hidden_dim = getattr(
model.config, "hidden_size", 384 if _is_e5_model else 768
)
pooled_embeddings = torch.zeros(
(batch_size, hidden_dim),
device=device,
dtype=enc["input_ids"].dtype
if hasattr(enc["input_ids"], "dtype")
else torch.float32,
)
elif _is_e5_model: elif _is_e5_model:
pooled_embeddings = e5_average_pool(out.last_hidden_state, enc['attention_mask']) pooled_embeddings = e5_average_pool(
out.last_hidden_state, enc["attention_mask"]
)
else: else:
hidden_states = out.last_hidden_state hidden_states = out.last_hidden_state
mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float() mask_expanded = (
enc["attention_mask"]
.unsqueeze(-1)
.expand(hidden_states.size())
.float()
)
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1) sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9) sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
pooled_embeddings = sum_embeddings / sum_mask pooled_embeddings = sum_embeddings / sum_mask
final_embeddings = pooled_embeddings final_embeddings = pooled_embeddings
if _is_e5_model or _is_bge_model: if _is_e5_model or _is_bge_model:
with norm_timer.timing(): with norm_timer.timing():
final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1) final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)
if torch.isnan(final_embeddings).any() or torch.isinf(final_embeddings).any(): if torch.isnan(final_embeddings).any() or torch.isinf(final_embeddings).any():
print(f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! " print(
f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}") f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! "
f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}"
)
dim_size = final_embeddings.shape[-1] dim_size = final_embeddings.shape[-1]
error_output = torch.zeros((batch_size, dim_size), device='cpu', dtype=torch.float32).numpy() error_output = torch.zeros(
print(f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}") (batch_size, dim_size), device="cpu", dtype=torch.float32
).numpy()
print(
f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}"
)
return error_output return error_output
return final_embeddings.cpu().numpy() return final_embeddings.cpu().numpy()
def client_warmup(zmq_port): def client_warmup(zmq_port):
@@ -371,7 +423,7 @@ def create_hnsw_embedding_server(
        time.sleep(2)
        print(f"Performing client-side warmup with model {model_name}...")
        sample_ids = ["1", "2", "3", "4", "5"]

        try:
            context = zmq.Context()
            socket = context.socket(zmq.REQ)
@@ -379,12 +431,12 @@ def create_hnsw_embedding_server(
            socket.setsockopt(zmq.RCVTIMEO, 30000)
            socket.setsockopt(zmq.SNDTIMEO, 30000)

            try:
                ids_to_send = [int(x) for x in sample_ids]
            except ValueError:
                ids_to_send = []

            if not ids_to_send:
                print("Skipping warmup send.")
                return
@@ -392,14 +444,18 @@ def create_hnsw_embedding_server(
            request_bytes = msgpack.packb(request_payload)

            for i in range(3):
                print(f"Sending warmup request {i + 1}/3 via ZMQ (MessagePack)...")
                socket.send(request_bytes)
                response_bytes = socket.recv()
                response_payload = msgpack.unpackb(response_bytes)
                dimensions = response_payload[0]
                embeddings_count = (
                    dimensions[0] if dimensions and len(dimensions) > 0 else 0
                )
                print(
                    f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings"
                )
                time.sleep(0.1)

            print("Client-side MessagePack ZMQ warmup complete")
@@ -410,6 +466,7 @@ def create_hnsw_embedding_server(
    def zmq_server_thread():
        """ZMQ server thread"""
        nonlocal passages, model, tokenizer, model_name
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        socket.bind(f"tcp://*:{zmq_port}")
@@ -428,94 +485,277 @@ def create_hnsw_embedding_server(
                try:
                    request_payload = msgpack.unpackb(message_bytes)
                    print(f"DEBUG: Raw request_payload: {request_payload}")
                    print(f"DEBUG: request_payload type: {type(request_payload)}")
                    if isinstance(request_payload, list):
                        print(f"DEBUG: request_payload length: {len(request_payload)}")
                        for i, item in enumerate(request_payload):
                            print(
                                f"DEBUG: request_payload[{i}]: {type(item)} - {item if len(str(item)) < 100 else str(item)[:100] + '...'}"
                            )

                    # Handle control messages for meta path and model management
                    if isinstance(request_payload, list) and len(request_payload) >= 1:
                        if request_payload[0] == "__QUERY_META_PATH__":
                            # Return the current meta path being used by the server
                            current_meta_path = (
                                getattr(passages, "_meta_path", "")
                                if hasattr(passages, "_meta_path")
                                else ""
                            )
                            response = [current_meta_path]
                            socket.send(msgpack.packb(response))
                            continue
                        elif (
                            request_payload[0] == "__UPDATE_META_PATH__"
                            and len(request_payload) >= 2
                        ):
                            # Update the server's meta path and reload passages
                            new_meta_path = request_payload[1]
                            try:
                                print(
                                    f"INFO: Updating server meta path to: {new_meta_path}"
                                )
                                # Reload passages from the new meta file
                                passages = load_passages_from_metadata(new_meta_path)
                                # Store the meta path for future queries
                                passages._meta_path = new_meta_path
                                response = ["SUCCESS"]
                                print(
                                    f"INFO: Successfully updated meta path and reloaded {len(passages)} passages"
                                )
                            except Exception as e:
                                print(f"ERROR: Failed to update meta path: {e}")
                                response = ["FAILED", str(e)]
                            socket.send(msgpack.packb(response))
                            continue
                        elif request_payload[0] == "__QUERY_MODEL__":
                            # Return the current model being used by the server
                            response = [model_name]
                            socket.send(msgpack.packb(response))
                            continue
                        elif (
                            request_payload[0] == "__UPDATE_MODEL__"
                            and len(request_payload) >= 2
                        ):
                            # Update the server's embedding model
                            new_model_name = request_payload[1]
                            try:
                                print(
                                    f"INFO: Updating server model from {model_name} to: {new_model_name}"
                                )
                                # Clean up old model to free memory
                                print("INFO: Releasing old model from memory...")
                                old_model = model
                                old_tokenizer = tokenizer

                                # Load new tokenizer first
                                print(f"Loading new tokenizer for {new_model_name}...")
                                tokenizer = AutoTokenizer.from_pretrained(
                                    new_model_name, use_fast=True
                                )

                                # Load new model
                                print(f"Loading new model {new_model_name}...")
                                model = AutoModel.from_pretrained(new_model_name)
                                model.to(device)
                                model.eval()

                                # Now safely delete old model after new one is loaded
                                del old_model
                                del old_tokenizer

                                # Clear GPU cache if available
                                if device.type == "cuda":
                                    torch.cuda.empty_cache()
                                    print("INFO: Cleared CUDA cache")
                                elif device.type == "mps":
                                    torch.mps.empty_cache()
                                    print("INFO: Cleared MPS cache")

                                # Update model name
                                model_name = new_model_name

                                # Force garbage collection
                                import gc

                                gc.collect()
                                print("INFO: Memory cleanup completed")

                                response = ["SUCCESS"]
                                print(
                                    f"INFO: Successfully updated model to: {new_model_name}"
                                )
                            except Exception as e:
                                print(f"ERROR: Failed to update model: {e}")
                                response = ["FAILED", str(e)]
                            socket.send(msgpack.packb(response))
                            continue

                    # Handle distance calculation requests
                    if (
                        isinstance(request_payload, list)
                        and len(request_payload) == 2
                        and isinstance(request_payload[0], list)
                        and isinstance(request_payload[1], list)
                    ):
                        node_ids = request_payload[0]
                        query_vector = np.array(request_payload[1], dtype=np.float32)
                        print("DEBUG: Distance calculation request received")
                        print(f"  Node IDs: {node_ids}")
                        print(f"  Query vector dim: {len(query_vector)}")
                        print(f"  Passages loaded: {len(passages)}")

                        # Get embeddings for node IDs
                        texts = []
                        missing_ids = []
                        with lookup_timer.timing():
                            for nid in node_ids:
                                print(f"DEBUG: Looking up passage ID {nid}")
                                try:
                                    txtinfo = passages[nid]
                                    if txtinfo is None:
                                        print(
                                            f"ERROR: Passage with ID {nid} returned None"
                                        )
                                        print(f"ERROR: txtinfo: {txtinfo}")
                                        raise RuntimeError(
                                            f"FATAL: Passage with ID {nid} returned None"
                                        )
                                    txt = txtinfo[
                                        "text"
                                    ]  # Allow empty text to pass through
                                    print(
                                        f"DEBUG: Found text for ID {nid}, length: {len(txt)}"
                                    )
                                    texts.append(txt)
                                except KeyError:
                                    print(
                                        f"ERROR: Passage ID {nid} not found in passages dict"
                                    )
                                    print(
                                        f"ERROR: Available passage IDs: {list(passages.keys())[:10]}..."
                                    )
                                    raise RuntimeError(
                                        f"FATAL: Passage with ID {nid} not found"
                                    )
                                except Exception as e:
                                    print(
                                        f"ERROR: Exception looking up passage ID {nid}: {e}"
                                    )
                                    raise
                        lookup_timer.print_elapsed()

                        # Process embeddings in chunks if needed
                        all_node_embeddings = []
                        total_size = len(texts)
                        if total_size > max_batch_size:
                            for i in range(0, total_size, max_batch_size):
                                end_idx = min(i + max_batch_size, total_size)
                                chunk_texts = texts[i:end_idx]
                                chunk_ids = node_ids[i:end_idx]
                                embeddings_chunk = process_batch(
                                    chunk_texts, chunk_ids, missing_ids
                                )
                                all_node_embeddings.append(embeddings_chunk)
                                if cuda_available:
                                    torch.cuda.empty_cache()
                                elif device.type == "mps":
                                    torch.mps.empty_cache()
                            node_embeddings = np.vstack(all_node_embeddings)
                        else:
                            node_embeddings = process_batch(
                                texts, node_ids, missing_ids
                            )

                        # Calculate distances
                        query_tensor = torch.tensor(query_vector, device=device).float()
                        node_embeddings_tensor = torch.tensor(
                            node_embeddings, device=device
                        ).float()

                        calc_timer = DeviceTimer("distance calculation", device)
                        with calc_timer.timing():
                            with torch.no_grad():
                                if distance_metric == "l2":
                                    node_embeddings_np = (
                                        node_embeddings_tensor.cpu()
                                        .numpy()
                                        .astype(np.float32)
                                    )
                                    query_np = (
                                        query_tensor.cpu().numpy().astype(np.float32)
                                    )
                                    distances = np.sum(
                                        np.square(
                                            node_embeddings_np - query_np.reshape(1, -1)
                                        ),
                                        axis=1,
                                    )
                                else:  # mips or cosine
                                    node_embeddings_np = (
                                        node_embeddings_tensor.cpu().numpy()
                                    )
                                    query_np = query_tensor.cpu().numpy()
                                    distances = -np.dot(node_embeddings_np, query_np)
                        calc_timer.print_elapsed()

                        try:
                            response_payload = distances.flatten().tolist()
                            response_bytes = msgpack.packb(
                                [response_payload], use_single_float=True
                            )
                            print(
                                f"Sending distance response with {len(distances)} distances"
                            )
                        except Exception as pack_error:
                            print(
                                f"ERROR: Error packing MessagePack distance response: {pack_error}"
                            )
                            print(f"ERROR: distances shape: {distances.shape}")
                            print(f"ERROR: distances dtype: {distances.dtype}")
                            print(f"ERROR: distances content: {distances}")
                            print(f"ERROR: node_ids: {node_ids}")
                            print(f"ERROR: query_vector shape: {query_vector.shape}")
                            # Still return empty for now but with full error info
                            response_bytes = msgpack.packb([[]])

                        socket.send(response_bytes)

                        if device.type == "cuda":
                            torch.cuda.synchronize()
                        elif device.type == "mps":
                            torch.mps.synchronize()

                        e2e_end = time.time()
                        print(
                            f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds"
                        )
                        continue

                    # Standard embedding request
                    if (
                        not isinstance(request_payload, list)
                        or len(request_payload) != 1
                        or not isinstance(request_payload[0], list)
                    ):
                        print(
                            f"Error: Invalid MessagePack request format. Expected [[ids...]], got: {type(request_payload)}"
                        )
                        socket.send(msgpack.packb([[], []]))
                        continue

                    node_ids = request_payload[0]
                    print(f"Request for {len(node_ids)} node embeddings")
                except Exception as unpack_error:
                    print(f"Error unpacking MessagePack request: {unpack_error}")
                    socket.send(msgpack.packb([[], []]))
@@ -529,11 +769,15 @@ def create_hnsw_embedding_server(
                            try:
                                txtinfo = passages[nid]
                                if txtinfo is None or txtinfo["text"] == "":
                                    raise RuntimeError(
                                        f"FATAL: Passage with ID {nid} not found - failing fast"
                                    )
                                else:
                                    txt = txtinfo["text"]
                            except (KeyError, IndexError):
                                raise RuntimeError(
                                    f"FATAL: Passage with ID {nid} not found - failing fast"
                                )
                            texts.append(txt)
                lookup_timer.print_elapsed()
@@ -542,27 +786,35 @@ def create_hnsw_embedding_server(
                # Process in chunks
                total_size = len(texts)
                print(
                    f"Total batch size: {total_size}, max_batch_size: {max_batch_size}"
                )

                all_embeddings = []
                if total_size > max_batch_size:
                    print(
                        f"Splitting batch of size {total_size} into chunks of {max_batch_size}"
                    )
                    for i in range(0, total_size, max_batch_size):
                        end_idx = min(i + max_batch_size, total_size)
                        print(
                            f"Processing chunk {i // max_batch_size + 1}/{(total_size + max_batch_size - 1) // max_batch_size}: items {i} to {end_idx - 1}"
                        )
                        chunk_texts = texts[i:end_idx]
                        chunk_ids = node_ids[i:end_idx]
                        embeddings_chunk = process_batch(
                            chunk_texts, chunk_ids, missing_ids
                        )
                        all_embeddings.append(embeddings_chunk)
                        if cuda_available:
                            torch.cuda.empty_cache()
                        elif device.type == "mps":
                            torch.mps.empty_cache()
                    hidden = np.vstack(all_embeddings)
                    print(f"Combined embeddings shape: {hidden.shape}")
                else:
@@ -571,22 +823,30 @@ def create_hnsw_embedding_server(
                # Serialization and response
                ser_start = time.time()
                print(
                    f"DEBUG zmq_server_thread: Final 'hidden' array | Shape: {hidden.shape} | Dtype: {hidden.dtype} | Has NaN/Inf: {np.isnan(hidden).any() or np.isinf(hidden).any()}"
                )
                if np.isnan(hidden).any() or np.isinf(hidden).any():
                    print(
                        f"{RED}!!! ERROR: NaN or Inf detected in final 'hidden' numpy array BEFORE sending! "
                        f"Requested IDs (sample): {node_ids[:5]}...{RESET}"
                    )
                    assert False

                try:
                    hidden_contiguous_f32 = np.ascontiguousarray(
                        hidden, dtype=np.float32
                    )
                    response_payload = [
                        list(hidden_contiguous_f32.shape),
                        hidden_contiguous_f32.flatten().tolist(),
                    ]
                    response_bytes = msgpack.packb(
                        response_payload, use_single_float=True
                    )
                except Exception as pack_error:
                    print(f"Error packing MessagePack response: {pack_error}")
                    response_bytes = msgpack.packb([[], []])

                socket.send(response_bytes)
                ser_end = time.time()
@@ -606,8 +866,9 @@ def create_hnsw_embedding_server(
            except Exception as e:
                print(f"Error in ZMQ server loop: {e}")
                import traceback

                traceback.print_exc()
                try:
                    socket.send(msgpack.packb([[], []]))
                except:
                    pass
@@ -621,7 +882,7 @@ def create_hnsw_embedding_server(
    zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
    zmq_thread.start()
    print(f"Started HNSW ZMQ server thread on port {zmq_port}")

    # Keep the main thread alive
    try:
        while True:
@@ -634,17 +895,41 @@ def create_hnsw_embedding_server(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="HNSW Embedding service") parser = argparse.ArgumentParser(description="HNSW Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on") parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping") parser.add_argument(
parser.add_argument("--embeddings-file", type=str, help="Pickle file containing pre-computed embeddings") "--passages-file",
type=str,
help="JSON file containing passage ID to text mapping",
)
parser.add_argument(
"--embeddings-file",
type=str,
help="Pickle file containing pre-computed embeddings",
)
parser.add_argument("--use-fp16", action="store_true", default=False) parser.add_argument("--use-fp16", action="store_true", default=False)
parser.add_argument("--use-int8", action="store_true", default=False) parser.add_argument("--use-int8", action="store_true", default=False)
parser.add_argument("--use-cuda-graphs", action="store_true", default=False) parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting") parser.add_argument(
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2", "--max-batch-size",
help="Embedding model name") type=int,
parser.add_argument("--custom-max-length", type=int, default=None, help="Override model's default max sequence length") default=128,
parser.add_argument("--distance-metric", type=str, default="mips", help="Distance metric to use") help="Maximum batch size before splitting",
)
parser.add_argument(
"--model-name",
type=str,
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name",
)
parser.add_argument(
"--custom-max-length",
type=int,
default=None,
help="Override model's default max sequence length",
)
parser.add_argument(
"--distance-metric", type=str, default="mips", help="Distance metric to use"
)
args = parser.parse_args() args = parser.parse_args()
# Create and start the HNSW embedding server # Create and start the HNSW embedding server
@@ -659,4 +944,4 @@ if __name__ == "__main__":
model_name=args.model_name,
custom_max_length_param=args.custom_max_length,
distance_metric=args.distance_metric,
)
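A hedged example of launching this service with the flags defined above. The module path passed to -m is a placeholder (the real package path is not shown in this hunk), and the passages file name is only an example.

import subprocess
import sys

# Placeholder module path and example meta file; adjust to the actual layout.
proc = subprocess.Popen(
    [
        sys.executable,
        "-m", "leann_backend_hnsw.hnsw_embedding_server",
        "--zmq-port", "5555",
        "--model-name", "sentence-transformers/all-mpnet-base-v2",
        "--passages-file", "wechat_history_index.leann.meta.json",
    ]
)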
View File

@@ -1,23 +1,169 @@
import os
import threading
import time
import atexit
import socket
import subprocess
import sys
import zmq
import msgpack
from pathlib import Path
from typing import Optional
import select
def _check_port(port: int) -> bool:
"""Check if a port is in use"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0
def _check_server_meta_path(port: int, expected_meta_path: str) -> bool:
"""
Check if the existing server on the port is using the correct meta file.
Returns True if the server has the right meta path, False otherwise.
"""
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
socket.connect(f"tcp://localhost:{port}")
# Send a special control message to query the server's meta path
control_request = ["__QUERY_META_PATH__"]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the response contains the meta path and if it matches
if isinstance(response, list) and len(response) > 0:
server_meta_path = response[0]
# Normalize paths for comparison
expected_path = Path(expected_meta_path).resolve()
server_path = Path(server_meta_path).resolve() if server_meta_path else None
return server_path == expected_path
return False
except Exception as e:
print(f"WARNING: Could not query server meta path on port {port}: {e}")
return False
def _update_server_meta_path(port: int, new_meta_path: str) -> bool:
"""
Send a control message to update the server's meta path.
Returns True if successful, False otherwise.
"""
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout
socket.connect(f"tcp://localhost:{port}")
# Send a control message to update the meta path
control_request = ["__UPDATE_META_PATH__", new_meta_path]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the update was successful
if isinstance(response, list) and len(response) > 0:
return response[0] == "SUCCESS"
return False
except Exception as e:
print(f"ERROR: Could not update server meta path on port {port}: {e}")
return False
def _check_server_model(port: int, expected_model: str) -> bool:
"""
Check if the existing server on the port is using the correct embedding model.
Returns True if the server has the right model, False otherwise.
"""
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
socket.connect(f"tcp://localhost:{port}")
# Send a special control message to query the server's model
control_request = ["__QUERY_MODEL__"]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the response contains the model name and if it matches
if isinstance(response, list) and len(response) > 0:
server_model = response[0]
return server_model == expected_model
return False
except Exception as e:
print(f"WARNING: Could not query server model on port {port}: {e}")
return False
def _update_server_model(port: int, new_model: str) -> bool:
"""
Send a control message to update the server's embedding model.
Returns True if successful, False otherwise.
"""
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout for model loading
socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout for sending
socket.connect(f"tcp://localhost:{port}")
# Send a control message to update the model
control_request = ["__UPDATE_MODEL__", new_model]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the update was successful
if isinstance(response, list) and len(response) > 0:
return response[0] == "SUCCESS"
return False
except Exception as e:
print(f"ERROR: Could not update server model on port {port}: {e}")
return False
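These four helpers assume the embedding server recognizes the __QUERY_*/__UPDATE_* control messages and replies with a one-element list ([value] for queries, ["SUCCESS"] for updates). The server-side dispatch is not part of this file; a rough sketch of what it would need to do, with hypothetical handler and state names, is:

def handle_control_request(request, state):
    # Hedged sketch only; the real handler lives in the embedding server.
    if not isinstance(request, list) or not request:
        return None  # not a control message; fall through to embedding path
    if request[0] == "__QUERY_META_PATH__":
        return [state.get("meta_path")]
    if request[0] == "__QUERY_MODEL__":
        return [state.get("model_name")]
    if request[0] == "__UPDATE_META_PATH__":
        state["meta_path"] = request[1]
        return ["SUCCESS"]
    if request[0] == "__UPDATE_MODEL__":
        state["model_name"] = request[1]
        # A real server would reload the embedding model here before replying.
        return ["SUCCESS"]
    return None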
class EmbeddingServerManager:
"""
A generic manager for handling the lifecycle of a backend-specific embedding server process.
"""
def __init__(self, backend_module_name: str):
"""
Initializes the manager for a specific backend.
@@ -44,21 +190,119 @@ class EmbeddingServerManager:
bool: True if the server is started successfully or already running, False otherwise.
"""
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})") # Even if we have a running process, check if model/meta path match
return True if self.server_port is not None:
port_in_use = _check_port(self.server_port)
if port_in_use:
print(
f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})"
)
# Check model compatibility
model_matches = _check_server_model(self.server_port, model_name)
if not model_matches:
print(
f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
)
if not _update_server_model(self.server_port, model_name):
print(
"❌ Failed to update existing server model. Restarting server..."
)
self.stop_server()
# Continue to start new server below
else:
print(
f"✅ Successfully updated existing server model to: {model_name}"
)
# Also check meta path if provided
passages_file = kwargs.get("passages_file")
if passages_file and str(passages_file).endswith(
".meta.json"
):
meta_matches = _check_server_meta_path(
self.server_port, str(passages_file)
)
if not meta_matches:
print("⚠️ Updating meta path to: {passages_file}")
_update_server_meta_path(
self.server_port, str(passages_file)
)
return True
else:
print(
f"✅ Existing server already using correct model: {model_name}"
)
return True
else:
# Server process exists but port not responding - restart
print("⚠️ Server process exists but not responding. Restarting...")
self.stop_server()
# Continue to start new server below
else:
# No port stored - restart
print("⚠️ No port information stored. Restarting server...")
self.stop_server()
# Continue to start new server below
if _check_port(port):
# Port is in use, check if it's using the correct meta file and model
passages_file = kwargs.get("passages_file")
print(f"INFO: Port {port} is in use. Checking server compatibility...")
# Check model compatibility first
model_matches = _check_server_model(port, model_name)
if not model_matches:
print(
f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
)
if not _update_server_model(port, model_name):
raise RuntimeError(
f"❌ Failed to update server model to {model_name}. Consider using a different port."
)
print(f"✅ Successfully updated server model to: {model_name}")
else:
print(
f"✅ Existing server on port {port} is using correct model: {model_name}"
)
# Check meta path compatibility if provided
if passages_file and str(passages_file).endswith(".meta.json"):
meta_matches = _check_server_meta_path(port, str(passages_file))
if not meta_matches:
print(
f"⚠️ Existing server on port {port} has different meta path. Attempting to update..."
)
if not _update_server_meta_path(port, str(passages_file)):
raise RuntimeError(
"❌ Failed to update server meta path. This may cause data synchronization issues."
)
print(
f"✅ Successfully updated server meta path to: {passages_file}"
)
else:
print(
f"✅ Existing server on port {port} is using correct meta path: {passages_file}"
)
print(f"✅ Server on port {port} is compatible and ready to use.")
return True
print(f"INFO: Starting session-level embedding server for '{self.backend_module_name}'...") print(
f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..."
)
try:
command = [
sys.executable,
"-m",
self.backend_module_name,
"--zmq-port",
str(port),
"--model-name",
model_name,
]
# Add extra arguments for specific backends
@@ -76,9 +320,9 @@ class EmbeddingServerManager:
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,  # Merge stderr into stdout for easier monitoring
text=True,
encoding="utf-8",
bufsize=1,  # Line buffered
universal_newlines=True,
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
@@ -86,17 +330,21 @@ class EmbeddingServerManager:
max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print("✅ Embedding server is up and ready for this session.")
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
log_thread.start()
return True
if self.server_process.poll() is not None:
print(
"❌ ERROR: Server process terminated unexpectedly during startup."
)
self._print_recent_output()
return False
time.sleep(wait_interval)
print(
f"❌ ERROR: Server process failed to start listening within {max_wait} seconds."
)
self.stop_server()
return False
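A hedged usage sketch of the startup path above, assuming start_server(port=..., model_name=..., **kwargs) and a placeholder backend module name:

# Illustrative only; module path and meta file name are assumptions.
manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
ok = manager.start_server(
    port=5555,
    model_name="sentence-transformers/all-mpnet-base-v2",
    passages_file="wechat_history_index.leann.meta.json",
)
if not ok:
    raise RuntimeError("Embedding server failed to start")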
@@ -110,8 +358,7 @@ class EmbeddingServerManager:
return
try:
# Read any available output
import select
import sys
if select.select([self.server_process.stdout], [], [], 0)[0]:
output = self.server_process.stdout.read()
if output:
@@ -129,19 +376,25 @@ class EmbeddingServerManager:
line = self.server_process.stdout.readline()
if not line:
break
print(
f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True
)
except Exception as e:
print(f"Log monitor error: {e}")
def stop_server(self):
"""Stops the embedding server process if it's running."""
if self.server_process and self.server_process.poll() is None:
print(
f"INFO: Terminating session server process (PID: {self.server_process.pid})..."
)
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print("INFO: Server process terminated.")
except subprocess.TimeoutExpired:
print(
"WARNING: Server process did not terminate gracefully, killing it."
)
self.server_process.kill()
self.server_process = None
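Since atexit is imported at the top of this module, cleanup can plausibly be registered so the child process is reaped on interpreter exit; this is a sketch, not code from this commit.

import atexit

manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")  # placeholder module name
atexit.register(manager.stop_server)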
View File

@@ -32,6 +32,8 @@ dependencies = [
"llama-index-node-parser-docling", "llama-index-node-parser-docling",
"ipykernel==6.29.5", "ipykernel==6.29.5",
"msgpack>=1.1.1", "msgpack>=1.1.1",
"llama-index-vector-stores-faiss>=0.4.0",
"llama-index-embeddings-huggingface>=0.5.5",
]
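The two dependencies added above pull in the FAISS vector store and HuggingFace embedding integrations for llama-index. A hedged sketch of the imports they provide follows; how this project wires them up is outside this diff, and faiss-cpu is assumed to be installed separately.

import faiss
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-mpnet-base-v2")
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatIP(768))  # 768 dims for all-mpnet-base-v2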
[project.optional-dependencies]