diff --git a/examples/wechat_history_reader_leann.py b/examples/wechat_history_reader_leann.py index f46a1c2..f92dbb7 100644 --- a/examples/wechat_history_reader_leann.py +++ b/examples/wechat_history_reader_leann.py @@ -14,43 +14,51 @@ dotenv.load_dotenv() # Default WeChat export directory DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct" -def create_leann_index_from_multiple_wechat_exports(export_dirs: List[Path], index_path: str = "wechat_history_index.leann", max_count: int = -1): + +def create_leann_index_from_multiple_wechat_exports( + export_dirs: List[Path], + index_path: str = "wechat_history_index.leann", + max_count: int = -1, +): """ Create LEANN index from multiple WeChat export data sources. - + Args: export_dirs: List of Path objects pointing to WeChat export directories index_path: Path to save the LEANN index max_count: Maximum number of chat entries to process per export """ print("Creating LEANN index from multiple WeChat export data sources...") - + # Load documents using WeChatHistoryReader from history_data from history_data.wechat_history import WeChatHistoryReader + reader = WeChatHistoryReader() - + INDEX_DIR = Path(index_path).parent - + if not INDEX_DIR.exists(): print(f"--- Index directory not found, building new index ---") all_documents = [] total_processed = 0 - + # Process each WeChat export directory for i, export_dir in enumerate(export_dirs): - print(f"\nProcessing WeChat export {i+1}/{len(export_dirs)}: {export_dir}") - + print( + f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}" + ) + try: documents = reader.load_data( wechat_export_dir=str(export_dir), max_count=max_count, - concatenate_messages=False # Disable concatenation - one message per document + concatenate_messages=False, # Disable concatenation - one message per document ) if documents: print(f"Loaded {len(documents)} chat documents from {export_dir}") all_documents.extend(documents) total_processed += len(documents) - + # Check if we've reached the max count if max_count > 0 and total_processed >= max_count: print(f"Reached max count of {max_count} documents") @@ -60,16 +68,18 @@ def create_leann_index_from_multiple_wechat_exports(export_dirs: List[Path], ind except Exception as e: print(f"Error processing {export_dir}: {e}") continue - + if not all_documents: print("No documents loaded from any source. 
Exiting.") return None - - print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports") - + + print( + f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports" + ) + # Create text splitter with 256 chunk size text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25) - + # Convert Documents to text strings and chunk them all_texts = [] for doc in all_documents: @@ -77,43 +87,50 @@ def create_leann_index_from_multiple_wechat_exports(export_dirs: List[Path], ind nodes = text_splitter.get_nodes_from_documents([doc]) for node in nodes: all_texts.append(node.get_content()) - - print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents") - + + print( + f"Created {len(all_texts)} text chunks from {len(all_documents)} documents" + ) + # Create LEANN index directory print(f"--- Index directory not found, building new index ---") INDEX_DIR.mkdir(exist_ok=True) print(f"--- Building new LEANN index ---") - + print(f"\n[PHASE 1] Building Leann index...") # Use HNSW backend for better macOS compatibility builder = LeannBuilder( backend_name="hnsw", - embedding_model="Qwen/Qwen3-Embedding-0.6B", - graph_degree=32, + embedding_model="Qwen/Qwen3-Embedding-0.6B", + graph_degree=32, complexity=64, is_compact=True, is_recompute=True, - num_threads=1 # Force single-threaded mode + num_threads=1, # Force single-threaded mode ) print(f"Adding {len(all_texts)} chat chunks to index...") for chunk_text in all_texts: builder.add_text(chunk_text) - + builder.build_index(index_path) print(f"\nLEANN index built at {index_path}!") else: print(f"--- Using existing index at {INDEX_DIR} ---") - + return index_path -def create_leann_index(export_dir: str = None, index_path: str = "wechat_history_index.leann", max_count: int = 1000): + +def create_leann_index( + export_dir: str = None, + index_path: str = "wechat_history_index.leann", + max_count: int = 1000, +): """ Create LEANN index from WeChat chat history data. - + Args: export_dir: Path to the WeChat export directory (optional, uses default if None) index_path: Path to save the LEANN index @@ -121,34 +138,35 @@ def create_leann_index(export_dir: str = None, index_path: str = "wechat_history """ print("Creating LEANN index from WeChat chat history data...") INDEX_DIR = Path(index_path).parent - + if not INDEX_DIR.exists(): print(f"--- Index directory not found, building new index ---") INDEX_DIR.mkdir(exist_ok=True) print(f"--- Building new LEANN index ---") - + print(f"\n[PHASE 1] Building Leann index...") # Load documents using WeChatHistoryReader from history_data from history_data.wechat_history import WeChatHistoryReader + reader = WeChatHistoryReader() - + documents = reader.load_data( wechat_export_dir=export_dir, max_count=max_count, - concatenate_messages=False # Disable concatenation - one message per document + concatenate_messages=False, # Disable concatenation - one message per document ) - + if not documents: print("No documents loaded. 
Exiting.") return None - + print(f"Loaded {len(documents)} chat documents") - + # Create text splitter with 256 chunk size text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25) - + # Convert Documents to text strings and chunk them all_texts = [] for doc in documents: @@ -156,54 +174,55 @@ def create_leann_index(export_dir: str = None, index_path: str = "wechat_history nodes = text_splitter.get_nodes_from_documents([doc]) for node in nodes: all_texts.append(node.get_content()) - + print(f"Created {len(all_texts)} text chunks from {len(documents)} documents") - + # Create LEANN index directory print(f"--- Index directory not found, building new index ---") INDEX_DIR.mkdir(exist_ok=True) print(f"--- Building new LEANN index ---") - + print(f"\n[PHASE 1] Building Leann index...") # Use HNSW backend for better macOS compatibility builder = LeannBuilder( backend_name="hnsw", embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ", # MLX-optimized model - graph_degree=32, + graph_degree=32, complexity=64, is_compact=True, is_recompute=True, - num_threads=1 # Force single-threaded mode + num_threads=1, # Force single-threaded mode ) print(f"Adding {len(all_texts)} chat chunks to index...") for chunk_text in all_texts: builder.add_text(chunk_text) - + builder.build_index(index_path) print(f"\nLEANN index built at {index_path}!") else: print(f"--- Using existing index at {INDEX_DIR} ---") - + return index_path + async def query_leann_index(index_path: str, query: str): """ Query the LEANN index. - + Args: index_path: Path to the LEANN index query: The query string """ print(f"\n[PHASE 2] Starting Leann chat session...") chat = LeannChat(index_path=index_path) - + print(f"You: {query}") chat_response = chat.ask( - query, - top_k=5, + query, + top_k=5, recompute_beighbor_embeddings=True, complexity=32, beam_width=1, @@ -212,52 +231,74 @@ async def query_leann_index(index_path: str, query: str): "model": "gpt-4o", "api_key": os.getenv("OPENAI_API_KEY"), }, - llm_kwargs={ - "temperature": 0.0, - "max_tokens": 1000 - } + llm_kwargs={"temperature": 0.0, "max_tokens": 1000}, ) print(f"Leann: {chat_response}") + async def main(): """Main function with integrated WeChat export functionality.""" - + # Parse command line arguments - parser = argparse.ArgumentParser(description='LEANN WeChat History Reader - Create and query WeChat chat history index') - parser.add_argument('--export-dir', type=str, default=DEFAULT_WECHAT_EXPORT_DIR, - help=f'Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})') - parser.add_argument('--index-dir', type=str, default="./wechat_history_index_leann_test", - help='Directory to store the LEANN index (default: ./wechat_history_index_leann_test)') - parser.add_argument('--max-entries', type=int, default=5000, - help='Maximum number of chat entries to process (default: 5000)') - parser.add_argument('--query', type=str, default=None, - help='Single query to run (default: runs example queries)') - parser.add_argument('--force-export', action='store_true', default=False, - help='Force re-export of WeChat data even if exports exist') - + parser = argparse.ArgumentParser( + description="LEANN WeChat History Reader - Create and query WeChat chat history index" + ) + parser.add_argument( + "--export-dir", + type=str, + default=DEFAULT_WECHAT_EXPORT_DIR, + help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})", + ) + parser.add_argument( + "--index-dir", + type=str, + default="./wechat_history_index_leann_test", + 
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)", + ) + parser.add_argument( + "--max-entries", + type=int, + default=5000, + help="Maximum number of chat entries to process (default: 5000)", + ) + parser.add_argument( + "--query", + type=str, + default=None, + help="Single query to run (default: runs example queries)", + ) + parser.add_argument( + "--force-export", + action="store_true", + default=False, + help="Force re-export of WeChat data even if exports exist", + ) + args = parser.parse_args() - + INDEX_DIR = Path(args.index_dir) INDEX_PATH = str(INDEX_DIR / "wechat_history.leann") - + print(f"Using WeChat export directory: {args.export_dir}") print(f"Index directory: {INDEX_DIR}") print(f"Max entries: {args.max_entries}") - + # Initialize WeChat reader with export capabilities from history_data.wechat_history import WeChatHistoryReader + reader = WeChatHistoryReader() - + # Find existing exports or create new ones using the centralized method export_dirs = reader.find_or_export_wechat_data(args.export_dir) - if not export_dirs: print("Failed to find or export WeChat data. Exiting.") return - + # Create or load the LEANN index from all sources - index_path = create_leann_index_from_multiple_wechat_exports(export_dirs, INDEX_PATH, max_count=args.max_entries) - + index_path = create_leann_index_from_multiple_wechat_exports( + export_dirs, INDEX_PATH, max_count=args.max_entries + ) + if index_path: if args.query: # Run single query @@ -267,10 +308,11 @@ async def main(): queries = [ "我想买魔术师约翰逊的球衣,给我一些对应聊天记录?", ] - + for query in queries: - print("\n" + "="*60) + print("\n" + "=" * 60) await query_leann_index(index_path, query) + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py index 49442d0..f7a59cd 100644 --- a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py +++ b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py @@ -14,6 +14,7 @@ import os from contextlib import contextmanager import zmq import numpy as np +import msgpack from pathlib import Path RED = "\033[91m" @@ -26,6 +27,7 @@ class SimplePassageLoader: """ def __init__(self, passages_data: Optional[Dict[str, Any]] = None): self.passages_data = passages_data or {} + self._meta_path = '' def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]: """Get passage by ID""" @@ -38,6 +40,9 @@ class SimplePassageLoader: def __len__(self) -> int: return len(self.passages_data) + + def keys(self): + return self.passages_data.keys() def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader: """ @@ -101,8 +106,13 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader: def __len__(self) -> int: return len(self.label_map) + + def keys(self): + return self.label_map.keys() - return LazyPassageLoader(passage_manager, label_map) + loader = LazyPassageLoader(passage_manager, label_map) + loader._meta_path = meta_file + return loader def load_passages_from_file(passages_file: str) -> SimplePassageLoader: """ @@ -353,6 +363,100 @@ def create_embedding_server_thread( continue print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes") + # Handle control messages (MessagePack format) + try: + request_payload = msgpack.unpackb(message) + if isinstance(request_payload, list) and 
len(request_payload) >= 1: + if request_payload[0] == "__QUERY_META_PATH__": + # Return the current meta path being used by the server + current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else '' + response = [current_meta_path] + socket.send_multipart([identity, b'', msgpack.packb(response)]) + continue + + elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2: + # Update the server's meta path and reload passages + new_meta_path = request_payload[1] + try: + print(f"INFO: Updating server meta path to: {new_meta_path}") + # Reload passages from the new meta file + passages = load_passages_from_metadata(new_meta_path) + # Store the meta path for future queries + passages._meta_path = new_meta_path + response = ["SUCCESS"] + print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages") + except Exception as e: + print(f"ERROR: Failed to update meta path: {e}") + response = ["FAILED", str(e)] + socket.send_multipart([identity, b'', msgpack.packb(response)]) + continue + + elif request_payload[0] == "__QUERY_MODEL__": + # Return the current model being used by the server + response = [model_name] + socket.send_multipart([identity, b'', msgpack.packb(response)]) + continue + + elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2: + # Update the server's embedding model + new_model_name = request_payload[1] + try: + print(f"INFO: Updating server model from {model_name} to: {new_model_name}") + + # Clean up old model to free memory + if not use_mlx: + print("INFO: Releasing old model from memory...") + old_model = model + old_tokenizer = tokenizer + + # Load new tokenizer first + print(f"Loading new tokenizer for {new_model_name}...") + tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True) + + # Load new model + print(f"Loading new model {new_model_name}...") + model = AutoModel.from_pretrained(new_model_name).to(device).eval() + + # Optimize new model + if cuda_available or mps_available: + try: + model = model.half() + model = torch.compile(model) + print(f"INFO: Using FP16 precision with model: {new_model_name}") + except Exception as e: + print(f"WARNING: Model optimization failed: {e}") + + # Now safely delete old model after new one is loaded + del old_model + del old_tokenizer + + # Clear GPU cache if available + if device.type == "cuda": + torch.cuda.empty_cache() + print("INFO: Cleared CUDA cache") + elif device.type == "mps": + torch.mps.empty_cache() + print("INFO: Cleared MPS cache") + + # Force garbage collection + import gc + gc.collect() + print("INFO: Memory cleanup completed") + + # Update model name + model_name = new_model_name + + response = ["SUCCESS"] + print(f"INFO: Successfully updated model to: {new_model_name}") + except Exception as e: + print(f"ERROR: Failed to update model: {e}") + response = ["FAILED", str(e)] + socket.send_multipart([identity, b'', msgpack.packb(response)]) + continue + except: + # Not a control message, continue with normal protobuf processing + pass + e2e_start = time.time() lookup_timer = DeviceTimer("text lookup") diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py index e5d06f4..a27140f 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py @@ -17,10 +17,12 @@ import msgpack import json from pathlib 
import Path from typing import Dict, Any, Optional, Union +import sys RED = "\033[91m" RESET = "\033[0m" + def is_similarity_metric(): """ Check if the metric type is similarity-based (like inner product). @@ -28,22 +30,27 @@ def is_similarity_metric(): """ return True # 1 is METRIC_INNER_PRODUCT in FAISS + # Function for E5-style average pooling import torch from torch import Tensor import torch.nn.functional as F + def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor: last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + class SimplePassageLoader: """ Simple passage loader that replaces config.py dependencies """ + def __init__(self, passages_data: Optional[Dict[str, Any]] = None): self.passages_data = passages_data or {} - + self._meta_path = "" + def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]: """Get passage by ID""" str_id = str(passage_id) @@ -52,54 +59,57 @@ class SimplePassageLoader: else: # Return empty text for missing passages return {"text": ""} - + def __len__(self) -> int: return len(self.passages_data) + def keys(self): + return self.passages_data.keys() + + def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader: """ Load passages using metadata file with PassageManager for lazy loading """ # Load metadata to get passage sources - with open(meta_file, 'r') as f: + with open(meta_file, "r") as f: meta = json.load(f) - + # Import PassageManager dynamically to avoid circular imports - import sys - import importlib.util - # Find the leann package directory relative to this file current_dir = Path(__file__).parent leann_core_path = current_dir.parent.parent / "leann-core" / "src" sys.path.insert(0, str(leann_core_path)) - + try: from leann.api import PassageManager - passage_manager = PassageManager(meta['passage_sources']) + + passage_manager = PassageManager(meta["passage_sources"]) finally: sys.path.pop(0) - - # Load label map + + # Load label map passages_dir = Path(meta_file).parent label_map_file = passages_dir / "leann.labels.map" - + if label_map_file.exists(): import pickle - with open(label_map_file, 'rb') as f: + + with open(label_map_file, "rb") as f: label_map = pickle.load(f) print(f"Loaded label map with {len(label_map)} entries") else: raise FileNotFoundError(f"Label map file not found: {label_map_file}") - + print(f"Initialized lazy passage loading for {len(label_map)} passages") - + class LazyPassageLoader(SimplePassageLoader): def __init__(self, passage_manager, label_map): self.passage_manager = passage_manager self.label_map = label_map # Initialize parent with empty data super().__init__({}) - + def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]: """Get passage by ID with lazy loading""" try: @@ -118,12 +128,16 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader: except Exception as e: print(f"DEBUG: Exception getting passage {passage_id}: {e}") return {"text": ""} - + def __len__(self) -> int: return len(self.label_map) - + + def keys(self): + return self.label_map.keys() + return LazyPassageLoader(passage_manager, label_map) + def create_hnsw_embedding_server( passages_file: Optional[str] = None, passages_data: Optional[Dict[str, str]] = None, @@ -139,7 +153,7 @@ def create_hnsw_embedding_server( ): """ Create and start a ZMQ-based embedding server for HNSW backend. 
- + Args: passages_file: Path to JSON file containing passage ID -> text mapping passages_data: Direct passage data dict (alternative to passages_file) @@ -156,14 +170,14 @@ def create_hnsw_embedding_server( print(f"Loading tokenizer for {model_name}...") tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) print(f"Tokenizer loaded successfully!") - + # Device setup - mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() + mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() cuda_available = torch.cuda.is_available() - + print(f"MPS available: {mps_available}") print(f"CUDA available: {cuda_available}") - + if cuda_available: device = torch.device("cuda") print("Using CUDA device") @@ -173,7 +187,7 @@ def create_hnsw_embedding_server( else: device = torch.device("cpu") print("Using CPU device (no GPU acceleration available)") - + # Load model to the appropriate device print(f"Starting HNSW server on port {zmq_port} with model {model_name}") print(f"Loading model {model_name}... (this may take a while if downloading)") @@ -182,9 +196,10 @@ def create_hnsw_embedding_server( # Check port availability import socket + def check_port(port): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) == 0 + return s.connect_ex(("localhost", port)) == 0 if check_port(zmq_port): print(f"{RED}Port {zmq_port} is already in use{RESET}") @@ -196,8 +211,14 @@ def create_hnsw_embedding_server( model = torch.compile(model) print(f"Using FP16 precision with model: {model_name}") elif use_int8: - print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization") - from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig + print( + "- Using TorchAO for Int8 dynamic activation and Int8 weight quantization" + ) + from torchao.quantization import ( + quantize_, + Int8DynamicActivationInt8WeightConfig, + ) + quantize_(model, Int8DynamicActivationInt8WeightConfig()) model = torch.compile(model) model.eval() @@ -209,8 +230,10 @@ def create_hnsw_embedding_server( print(f"Using provided passages data: {len(passages)} passages") elif passages_file: # Check if it's a metadata file or a single passages file - if passages_file.endswith('.meta.json'): + if passages_file.endswith(".meta.json"): passages = load_passages_from_metadata(passages_file) + # Store the meta path for future reference + passages._meta_path = passages_file else: # Try to find metadata file in same directory passages_dir = Path(passages_file).parent @@ -220,8 +243,12 @@ def create_hnsw_embedding_server( passages = load_passages_from_metadata(str(meta_files[0])) else: # Fallback to original single file loading (will cause warnings) - print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)") - passages = SimplePassageLoader() # Use empty loader to avoid massive warnings + print( + "WARNING: No metadata file found, using single file loading (may cause missing passage warnings)" + ) + passages = ( + SimplePassageLoader() + ) # Use empty loader to avoid massive warnings else: passages = SimplePassageLoader() print("No passages provided, using empty loader") @@ -238,12 +265,13 @@ def create_hnsw_embedding_server( class DeviceTimer: """Device event-based timer for accurate timing.""" + def __init__(self, name="", device=device): self.name = name self.device = device self.start_time = 0 self.end_time = 0 - + if cuda_available: self.start_event = 
torch.cuda.Event(enable_timing=True) self.end_event = torch.cuda.Event(enable_timing=True) @@ -289,30 +317,31 @@ def create_hnsw_embedding_server( _is_e5_model = "e5" in model_name.lower() _is_bge_model = "bge" in model_name.lower() batch_size = len(texts_batch) - - # Validate no empty texts - for i, text in enumerate(texts_batch): - if not text or text.strip() == "": - raise RuntimeError(f"FATAL: Empty text at batch index {i}, ID: {ids_batch[i] if i < len(ids_batch) else 'unknown'}") - + + # Allow empty texts to pass through (remove validation) + # E5 model preprocessing if _is_e5_model: processed_texts_batch = [f"passage: {text}" for text in texts_batch] else: processed_texts_batch = texts_batch - + # Set max length if _is_e5_model: - current_max_length = custom_max_length_param if custom_max_length_param is not None else 512 + current_max_length = ( + custom_max_length_param if custom_max_length_param is not None else 512 + ) else: - current_max_length = custom_max_length_param if custom_max_length_param is not None else 256 - + current_max_length = ( + custom_max_length_param if custom_max_length_param is not None else 256 + ) + tokenize_timer = DeviceTimer("tokenization (batch)", device) to_device_timer = DeviceTimer("transfer to device (batch)", device) embed_timer = DeviceTimer("embedding (batch)", device) pool_timer = DeviceTimer("pooling (batch)", device) norm_timer = DeviceTimer("normalization (batch)", device) - + with tokenize_timer.timing(): encoded_batch = tokenizer( processed_texts_batch, @@ -322,48 +351,71 @@ def create_hnsw_embedding_server( return_tensors="pt", return_token_type_ids=False, ) - + seq_length = encoded_batch["input_ids"].size(1) - + with to_device_timer.timing(): enc = {k: v.to(device) for k, v in encoded_batch.items()} - + with torch.no_grad(): with embed_timer.timing(): out = model(enc["input_ids"], enc["attention_mask"]) - + with pool_timer.timing(): if _is_bge_model: pooled_embeddings = out.last_hidden_state[:, 0] - elif not hasattr(out, 'last_hidden_state'): + elif not hasattr(out, "last_hidden_state"): if isinstance(out, torch.Tensor) and len(out.shape) == 2: pooled_embeddings = out else: - print(f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}") - hidden_dim = getattr(model.config, 'hidden_size', 384 if _is_e5_model else 768) - pooled_embeddings = torch.zeros((batch_size, hidden_dim), device=device, dtype=enc["input_ids"].dtype if hasattr(enc["input_ids"], "dtype") else torch.float32) + print( + f"{RED}ERROR: Cannot determine how to pool. 
Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}" + ) + hidden_dim = getattr( + model.config, "hidden_size", 384 if _is_e5_model else 768 + ) + pooled_embeddings = torch.zeros( + (batch_size, hidden_dim), + device=device, + dtype=enc["input_ids"].dtype + if hasattr(enc["input_ids"], "dtype") + else torch.float32, + ) elif _is_e5_model: - pooled_embeddings = e5_average_pool(out.last_hidden_state, enc['attention_mask']) + pooled_embeddings = e5_average_pool( + out.last_hidden_state, enc["attention_mask"] + ) else: hidden_states = out.last_hidden_state - mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float() + mask_expanded = ( + enc["attention_mask"] + .unsqueeze(-1) + .expand(hidden_states.size()) + .float() + ) sum_embeddings = torch.sum(hidden_states * mask_expanded, 1) sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9) pooled_embeddings = sum_embeddings / sum_mask - + final_embeddings = pooled_embeddings if _is_e5_model or _is_bge_model: with norm_timer.timing(): final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1) - + if torch.isnan(final_embeddings).any() or torch.isinf(final_embeddings).any(): - print(f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! " - f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}") + print( + f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! " + f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}" + ) dim_size = final_embeddings.shape[-1] - error_output = torch.zeros((batch_size, dim_size), device='cpu', dtype=torch.float32).numpy() - print(f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}") + error_output = torch.zeros( + (batch_size, dim_size), device="cpu", dtype=torch.float32 + ).numpy() + print( + f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}" + ) return error_output - + return final_embeddings.cpu().numpy() def client_warmup(zmq_port): @@ -371,7 +423,7 @@ def create_hnsw_embedding_server( time.sleep(2) print(f"Performing client-side warmup with model {model_name}...") sample_ids = ["1", "2", "3", "4", "5"] - + try: context = zmq.Context() socket = context.socket(zmq.REQ) @@ -379,12 +431,12 @@ def create_hnsw_embedding_server( socket.setsockopt(zmq.RCVTIMEO, 30000) socket.setsockopt(zmq.SNDTIMEO, 30000) - try: + try: ids_to_send = [int(x) for x in sample_ids] - except ValueError: + except ValueError: ids_to_send = [] - if not ids_to_send: + if not ids_to_send: print("Skipping warmup send.") return @@ -392,14 +444,18 @@ def create_hnsw_embedding_server( request_bytes = msgpack.packb(request_payload) for i in range(3): - print(f"Sending warmup request {i+1}/3 via ZMQ (MessagePack)...") + print(f"Sending warmup request {i + 1}/3 via ZMQ (MessagePack)...") socket.send(request_bytes) response_bytes = socket.recv() response_payload = msgpack.unpackb(response_bytes) dimensions = response_payload[0] - embeddings_count = dimensions[0] if dimensions and len(dimensions) > 0 else 0 - print(f"Warmup request {i+1}/3 successful, received {embeddings_count} embeddings") + embeddings_count = ( + dimensions[0] if dimensions and len(dimensions) > 0 else 0 + ) + print( + f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings" + ) time.sleep(0.1) print("Client-side MessagePack ZMQ warmup complete") @@ -410,6 +466,7 @@ def create_hnsw_embedding_server( def zmq_server_thread(): """ZMQ server 
thread""" + nonlocal passages, model, tokenizer, model_name context = zmq.Context() socket = context.socket(zmq.REP) socket.bind(f"tcp://*:{zmq_port}") @@ -428,94 +485,277 @@ def create_hnsw_embedding_server( try: request_payload = msgpack.unpackb(message_bytes) - + print(f"DEBUG: Raw request_payload: {request_payload}") + print(f"DEBUG: request_payload type: {type(request_payload)}") + if isinstance(request_payload, list): + print(f"DEBUG: request_payload length: {len(request_payload)}") + for i, item in enumerate(request_payload): + print( + f"DEBUG: request_payload[{i}]: {type(item)} - {item if len(str(item)) < 100 else str(item)[:100] + '...'}" + ) + + # Handle control messages for meta path and model management + if isinstance(request_payload, list) and len(request_payload) >= 1: + if request_payload[0] == "__QUERY_META_PATH__": + # Return the current meta path being used by the server + current_meta_path = ( + getattr(passages, "_meta_path", "") + if hasattr(passages, "_meta_path") + else "" + ) + response = [current_meta_path] + socket.send(msgpack.packb(response)) + continue + + elif ( + request_payload[0] == "__UPDATE_META_PATH__" + and len(request_payload) >= 2 + ): + # Update the server's meta path and reload passages + new_meta_path = request_payload[1] + try: + print( + f"INFO: Updating server meta path to: {new_meta_path}" + ) + # Reload passages from the new meta file + passages = load_passages_from_metadata(new_meta_path) + # Store the meta path for future queries + passages._meta_path = new_meta_path + response = ["SUCCESS"] + print( + f"INFO: Successfully updated meta path and reloaded {len(passages)} passages" + ) + except Exception as e: + print(f"ERROR: Failed to update meta path: {e}") + response = ["FAILED", str(e)] + socket.send(msgpack.packb(response)) + continue + + elif request_payload[0] == "__QUERY_MODEL__": + # Return the current model being used by the server + response = [model_name] + socket.send(msgpack.packb(response)) + continue + + elif ( + request_payload[0] == "__UPDATE_MODEL__" + and len(request_payload) >= 2 + ): + # Update the server's embedding model + new_model_name = request_payload[1] + try: + print( + f"INFO: Updating server model from {model_name} to: {new_model_name}" + ) + + # Clean up old model to free memory + print("INFO: Releasing old model from memory...") + old_model = model + old_tokenizer = tokenizer + + # Load new tokenizer first + print(f"Loading new tokenizer for {new_model_name}...") + tokenizer = AutoTokenizer.from_pretrained( + new_model_name, use_fast=True + ) + + # Load new model + print(f"Loading new model {new_model_name}...") + model = AutoModel.from_pretrained(new_model_name) + model.to(device) + model.eval() + + # Now safely delete old model after new one is loaded + del old_model + del old_tokenizer + + # Clear GPU cache if available + if device.type == "cuda": + torch.cuda.empty_cache() + print("INFO: Cleared CUDA cache") + elif device.type == "mps": + torch.mps.empty_cache() + print("INFO: Cleared MPS cache") + + # Update model name + model_name = new_model_name + + # Force garbage collection + import gc + + gc.collect() + print("INFO: Memory cleanup completed") + + response = ["SUCCESS"] + print( + f"INFO: Successfully updated model to: {new_model_name}" + ) + except Exception as e: + print(f"ERROR: Failed to update model: {e}") + response = ["FAILED", str(e)] + socket.send(msgpack.packb(response)) + continue + # Handle distance calculation requests - if isinstance(request_payload, list) and len(request_payload) == 
2 and isinstance(request_payload[0], list) and isinstance(request_payload[1], list): + if ( + isinstance(request_payload, list) + and len(request_payload) == 2 + and isinstance(request_payload[0], list) + and isinstance(request_payload[1], list) + ): node_ids = request_payload[0] query_vector = np.array(request_payload[1], dtype=np.float32) - - print(f"Request for distance calculation: {len(node_ids)} nodes, query vector dim: {len(query_vector)}") - + + print("DEBUG: Distance calculation request received") + print(f" Node IDs: {node_ids}") + print(f" Query vector dim: {len(query_vector)}") + print(f" Passages loaded: {len(passages)}") + # Get embeddings for node IDs texts = [] missing_ids = [] with lookup_timer.timing(): for nid in node_ids: print(f"DEBUG: Looking up passage ID {nid}") - txtinfo = passages[nid] - if txtinfo is None or txtinfo["text"] == "": - raise RuntimeError(f"FATAL: Passage with ID {nid} returned empty text") - txt = txtinfo["text"] - print(f"DEBUG: Found text for ID {nid}, length: {len(txt)}") - texts.append(txt) + try: + txtinfo = passages[nid] + if txtinfo is None: + print( + f"ERROR: Passage with ID {nid} returned None" + ) + print(f"ERROR: txtinfo: {txtinfo}") + raise RuntimeError( + f"FATAL: Passage with ID {nid} returned None" + ) + txt = txtinfo[ + "text" + ] # Allow empty text to pass through + print( + f"DEBUG: Found text for ID {nid}, length: {len(txt)}" + ) + texts.append(txt) + except KeyError: + print( + f"ERROR: Passage ID {nid} not found in passages dict" + ) + print( + f"ERROR: Available passage IDs: {list(passages.keys())[:10]}..." + ) + raise RuntimeError( + f"FATAL: Passage with ID {nid} not found" + ) + except Exception as e: + print( + f"ERROR: Exception looking up passage ID {nid}: {e}" + ) + raise lookup_timer.print_elapsed() - + # Process embeddings in chunks if needed all_node_embeddings = [] total_size = len(texts) - + if total_size > max_batch_size: for i in range(0, total_size, max_batch_size): end_idx = min(i + max_batch_size, total_size) chunk_texts = texts[i:end_idx] chunk_ids = node_ids[i:end_idx] - - embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids) + + embeddings_chunk = process_batch( + chunk_texts, chunk_ids, missing_ids + ) all_node_embeddings.append(embeddings_chunk) - + if cuda_available: torch.cuda.empty_cache() elif device.type == "mps": torch.mps.empty_cache() - + node_embeddings = np.vstack(all_node_embeddings) else: - node_embeddings = process_batch(texts, node_ids, missing_ids) - + node_embeddings = process_batch( + texts, node_ids, missing_ids + ) + # Calculate distances query_tensor = torch.tensor(query_vector, device=device).float() - node_embeddings_tensor = torch.tensor(node_embeddings, device=device).float() - + node_embeddings_tensor = torch.tensor( + node_embeddings, device=device + ).float() + calc_timer = DeviceTimer("distance calculation", device) with calc_timer.timing(): with torch.no_grad(): if distance_metric == "l2": - node_embeddings_np = node_embeddings_tensor.cpu().numpy().astype(np.float32) - query_np = query_tensor.cpu().numpy().astype(np.float32) - distances = np.sum(np.square(node_embeddings_np - query_np.reshape(1, -1)), axis=1) - else: # mips or cosine - node_embeddings_np = node_embeddings_tensor.cpu().numpy() + node_embeddings_np = ( + node_embeddings_tensor.cpu() + .numpy() + .astype(np.float32) + ) + query_np = ( + query_tensor.cpu().numpy().astype(np.float32) + ) + distances = np.sum( + np.square( + node_embeddings_np - query_np.reshape(1, -1) + ), + axis=1, + ) + else: # 
mips or cosine + node_embeddings_np = ( + node_embeddings_tensor.cpu().numpy() + ) query_np = query_tensor.cpu().numpy() distances = -np.dot(node_embeddings_np, query_np) calc_timer.print_elapsed() - + try: response_payload = distances.flatten().tolist() - response_bytes = msgpack.packb([response_payload], use_single_float=True) - print(f"Sending distance response with {len(distances)} distances") + response_bytes = msgpack.packb( + [response_payload], use_single_float=True + ) + print( + f"Sending distance response with {len(distances)} distances" + ) except Exception as pack_error: - print(f"Error packing MessagePack distance response: {pack_error}") + print( + f"ERROR: Error packing MessagePack distance response: {pack_error}" + ) + print(f"ERROR: distances shape: {distances.shape}") + print(f"ERROR: distances dtype: {distances.dtype}") + print(f"ERROR: distances content: {distances}") + print(f"ERROR: node_ids: {node_ids}") + print(f"ERROR: query_vector shape: {query_vector.shape}") + # Still return empty for now but with full error info response_bytes = msgpack.packb([[]]) - + socket.send(response_bytes) - + if device.type == "cuda": torch.cuda.synchronize() elif device.type == "mps": torch.mps.synchronize() e2e_end = time.time() - print(f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds") + print( + f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds" + ) continue - + # Standard embedding request - if not isinstance(request_payload, list) or len(request_payload) != 1 or not isinstance(request_payload[0], list): - print(f"Error: Invalid MessagePack request format. Expected [[ids...]], got: {type(request_payload)}") + if ( + not isinstance(request_payload, list) + or len(request_payload) != 1 + or not isinstance(request_payload[0], list) + ): + print( + f"Error: Invalid MessagePack request format. 
Expected [[ids...]], got: {type(request_payload)}" + ) socket.send(msgpack.packb([[], []])) continue - + node_ids = request_payload[0] print(f"Request for {len(node_ids)} node embeddings") - + except Exception as unpack_error: print(f"Error unpacking MessagePack request: {unpack_error}") socket.send(msgpack.packb([[], []])) @@ -529,11 +769,15 @@ def create_hnsw_embedding_server( try: txtinfo = passages[nid] if txtinfo is None or txtinfo["text"] == "": - raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast") + raise RuntimeError( + f"FATAL: Passage with ID {nid} not found - failing fast" + ) else: txt = txtinfo["text"] except (KeyError, IndexError): - raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast") + raise RuntimeError( + f"FATAL: Passage with ID {nid} not found - failing fast" + ) texts.append(txt) lookup_timer.print_elapsed() @@ -542,27 +786,35 @@ def create_hnsw_embedding_server( # Process in chunks total_size = len(texts) - print(f"Total batch size: {total_size}, max_batch_size: {max_batch_size}") - + print( + f"Total batch size: {total_size}, max_batch_size: {max_batch_size}" + ) + all_embeddings = [] - + if total_size > max_batch_size: - print(f"Splitting batch of size {total_size} into chunks of {max_batch_size}") + print( + f"Splitting batch of size {total_size} into chunks of {max_batch_size}" + ) for i in range(0, total_size, max_batch_size): end_idx = min(i + max_batch_size, total_size) - print(f"Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}") - + print( + f"Processing chunk {i // max_batch_size + 1}/{(total_size + max_batch_size - 1) // max_batch_size}: items {i} to {end_idx - 1}" + ) + chunk_texts = texts[i:end_idx] chunk_ids = node_ids[i:end_idx] - - embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids) + + embeddings_chunk = process_batch( + chunk_texts, chunk_ids, missing_ids + ) all_embeddings.append(embeddings_chunk) - + if cuda_available: torch.cuda.empty_cache() elif device.type == "mps": torch.mps.empty_cache() - + hidden = np.vstack(all_embeddings) print(f"Combined embeddings shape: {hidden.shape}") else: @@ -571,22 +823,30 @@ def create_hnsw_embedding_server( # Serialization and response ser_start = time.time() - print(f"DEBUG zmq_server_thread: Final 'hidden' array | Shape: {hidden.shape} | Dtype: {hidden.dtype} | Has NaN/Inf: {np.isnan(hidden).any() or np.isinf(hidden).any()}") + print( + f"DEBUG zmq_server_thread: Final 'hidden' array | Shape: {hidden.shape} | Dtype: {hidden.dtype} | Has NaN/Inf: {np.isnan(hidden).any() or np.isinf(hidden).any()}" + ) if np.isnan(hidden).any() or np.isinf(hidden).any(): - print(f"{RED}!!! ERROR: NaN or Inf detected in final 'hidden' numpy array BEFORE sending! " - f"Requested IDs (sample): {node_ids[:5]}...{RESET}") + print( + f"{RED}!!! ERROR: NaN or Inf detected in final 'hidden' numpy array BEFORE sending! 
" + f"Requested IDs (sample): {node_ids[:5]}...{RESET}" + ) assert False try: - hidden_contiguous_f32 = np.ascontiguousarray(hidden, dtype=np.float32) + hidden_contiguous_f32 = np.ascontiguousarray( + hidden, dtype=np.float32 + ) response_payload = [ list(hidden_contiguous_f32.shape), - hidden_contiguous_f32.flatten().tolist() + hidden_contiguous_f32.flatten().tolist(), ] - response_bytes = msgpack.packb(response_payload, use_single_float=True) + response_bytes = msgpack.packb( + response_payload, use_single_float=True + ) except Exception as pack_error: - print(f"Error packing MessagePack response: {pack_error}") - response_bytes = msgpack.packb([[], []]) + print(f"Error packing MessagePack response: {pack_error}") + response_bytes = msgpack.packb([[], []]) socket.send(response_bytes) ser_end = time.time() @@ -606,8 +866,9 @@ def create_hnsw_embedding_server( except Exception as e: print(f"Error in ZMQ server loop: {e}") import traceback + traceback.print_exc() - try: + try: socket.send(msgpack.packb([[], []])) except: pass @@ -621,7 +882,7 @@ def create_hnsw_embedding_server( zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True) zmq_thread.start() print(f"Started HNSW ZMQ server thread on port {zmq_port}") - + # Keep the main thread alive try: while True: @@ -634,17 +895,41 @@ def create_hnsw_embedding_server( if __name__ == "__main__": parser = argparse.ArgumentParser(description="HNSW Embedding service") parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on") - parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping") - parser.add_argument("--embeddings-file", type=str, help="Pickle file containing pre-computed embeddings") + parser.add_argument( + "--passages-file", + type=str, + help="JSON file containing passage ID to text mapping", + ) + parser.add_argument( + "--embeddings-file", + type=str, + help="Pickle file containing pre-computed embeddings", + ) parser.add_argument("--use-fp16", action="store_true", default=False) parser.add_argument("--use-int8", action="store_true", default=False) parser.add_argument("--use-cuda-graphs", action="store_true", default=False) - parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting") - parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2", - help="Embedding model name") - parser.add_argument("--custom-max-length", type=int, default=None, help="Override model's default max sequence length") - parser.add_argument("--distance-metric", type=str, default="mips", help="Distance metric to use") - + parser.add_argument( + "--max-batch-size", + type=int, + default=128, + help="Maximum batch size before splitting", + ) + parser.add_argument( + "--model-name", + type=str, + default="sentence-transformers/all-mpnet-base-v2", + help="Embedding model name", + ) + parser.add_argument( + "--custom-max-length", + type=int, + default=None, + help="Override model's default max sequence length", + ) + parser.add_argument( + "--distance-metric", type=str, default="mips", help="Distance metric to use" + ) + args = parser.parse_args() # Create and start the HNSW embedding server @@ -659,4 +944,4 @@ if __name__ == "__main__": model_name=args.model_name, custom_max_length_param=args.custom_max_length, distance_metric=args.distance_metric, - ) \ No newline at end of file + ) diff --git a/packages/leann-core/src/leann/embedding_server_manager.py 
b/packages/leann-core/src/leann/embedding_server_manager.py index 7205d87..9f5aaf3 100644 --- a/packages/leann-core/src/leann/embedding_server_manager.py +++ b/packages/leann-core/src/leann/embedding_server_manager.py @@ -1,23 +1,169 @@ - -import os import threading import time import atexit import socket import subprocess import sys +import zmq +import msgpack from pathlib import Path from typing import Optional +import select + def _check_port(port: int) -> bool: """Check if a port is in use""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) == 0 + return s.connect_ex(("localhost", port)) == 0 + + +def _check_server_meta_path(port: int, expected_meta_path: str) -> bool: + """ + Check if the existing server on the port is using the correct meta file. + Returns True if the server has the right meta path, False otherwise. + """ + try: + context = zmq.Context() + socket = context.socket(zmq.REQ) + socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout + socket.connect(f"tcp://localhost:{port}") + + # Send a special control message to query the server's meta path + control_request = ["__QUERY_META_PATH__"] + request_bytes = msgpack.packb(control_request) + socket.send(request_bytes) + + # Wait for response + response_bytes = socket.recv() + response = msgpack.unpackb(response_bytes) + + socket.close() + context.term() + + # Check if the response contains the meta path and if it matches + if isinstance(response, list) and len(response) > 0: + server_meta_path = response[0] + # Normalize paths for comparison + expected_path = Path(expected_meta_path).resolve() + server_path = Path(server_meta_path).resolve() if server_meta_path else None + return server_path == expected_path + + return False + + except Exception as e: + print(f"WARNING: Could not query server meta path on port {port}: {e}") + return False + + +def _update_server_meta_path(port: int, new_meta_path: str) -> bool: + """ + Send a control message to update the server's meta path. + Returns True if successful, False otherwise. + """ + try: + context = zmq.Context() + socket = context.socket(zmq.REQ) + socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout + socket.connect(f"tcp://localhost:{port}") + + # Send a control message to update the meta path + control_request = ["__UPDATE_META_PATH__", new_meta_path] + request_bytes = msgpack.packb(control_request) + socket.send(request_bytes) + + # Wait for response + response_bytes = socket.recv() + response = msgpack.unpackb(response_bytes) + + socket.close() + context.term() + + # Check if the update was successful + if isinstance(response, list) and len(response) > 0: + return response[0] == "SUCCESS" + + return False + + except Exception as e: + print(f"ERROR: Could not update server meta path on port {port}: {e}") + return False + + +def _check_server_model(port: int, expected_model: str) -> bool: + """ + Check if the existing server on the port is using the correct embedding model. + Returns True if the server has the right model, False otherwise. 
+ """ + try: + context = zmq.Context() + socket = context.socket(zmq.REQ) + socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout + socket.connect(f"tcp://localhost:{port}") + + # Send a special control message to query the server's model + control_request = ["__QUERY_MODEL__"] + request_bytes = msgpack.packb(control_request) + socket.send(request_bytes) + + # Wait for response + response_bytes = socket.recv() + response = msgpack.unpackb(response_bytes) + + socket.close() + context.term() + + # Check if the response contains the model name and if it matches + if isinstance(response, list) and len(response) > 0: + server_model = response[0] + return server_model == expected_model + + return False + + except Exception as e: + print(f"WARNING: Could not query server model on port {port}: {e}") + return False + + +def _update_server_model(port: int, new_model: str) -> bool: + """ + Send a control message to update the server's embedding model. + Returns True if successful, False otherwise. + """ + try: + context = zmq.Context() + socket = context.socket(zmq.REQ) + socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout for model loading + socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout for sending + socket.connect(f"tcp://localhost:{port}") + + # Send a control message to update the model + control_request = ["__UPDATE_MODEL__", new_model] + request_bytes = msgpack.packb(control_request) + socket.send(request_bytes) + + # Wait for response + response_bytes = socket.recv() + response = msgpack.unpackb(response_bytes) + + socket.close() + context.term() + + # Check if the update was successful + if isinstance(response, list) and len(response) > 0: + return response[0] == "SUCCESS" + + return False + + except Exception as e: + print(f"ERROR: Could not update server model on port {port}: {e}") + return False + class EmbeddingServerManager: """ A generic manager for handling the lifecycle of a backend-specific embedding server process. """ + def __init__(self, backend_module_name: str): """ Initializes the manager for a specific backend. @@ -44,21 +190,119 @@ class EmbeddingServerManager: bool: True if the server is started successfully or already running, False otherwise. """ if self.server_process and self.server_process.poll() is None: - print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})") - return True + # Even if we have a running process, check if model/meta path match + if self.server_port is not None: + port_in_use = _check_port(self.server_port) + if port_in_use: + print( + f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})" + ) + + # Check model compatibility + model_matches = _check_server_model(self.server_port, model_name) + if not model_matches: + print( + f"⚠️ Existing server has different model. Attempting to update to: {model_name}" + ) + if not _update_server_model(self.server_port, model_name): + print( + "❌ Failed to update existing server model. Restarting server..." 
+ ) + self.stop_server() + # Continue to start new server below + else: + print( + f"✅ Successfully updated existing server model to: {model_name}" + ) + + # Also check meta path if provided + passages_file = kwargs.get("passages_file") + if passages_file and str(passages_file).endswith( + ".meta.json" + ): + meta_matches = _check_server_meta_path( + self.server_port, str(passages_file) + ) + if not meta_matches: + print("⚠️ Updating meta path to: {passages_file}") + _update_server_meta_path( + self.server_port, str(passages_file) + ) + + return True + else: + print( + f"✅ Existing server already using correct model: {model_name}" + ) + return True + else: + # Server process exists but port not responding - restart + print("⚠️ Server process exists but not responding. Restarting...") + self.stop_server() + # Continue to start new server below + else: + # No port stored - restart + print("⚠️ No port information stored. Restarting server...") + self.stop_server() + # Continue to start new server below if _check_port(port): - print(f"WARNING: Port {port} is already in use. Assuming an external server is running.") + # Port is in use, check if it's using the correct meta file and model + passages_file = kwargs.get("passages_file") + + print(f"INFO: Port {port} is in use. Checking server compatibility...") + + # Check model compatibility first + model_matches = _check_server_model(port, model_name) + if not model_matches: + print( + f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}" + ) + if not _update_server_model(port, model_name): + raise RuntimeError( + f"❌ Failed to update server model to {model_name}. Consider using a different port." + ) + print(f"✅ Successfully updated server model to: {model_name}") + else: + print( + f"✅ Existing server on port {port} is using correct model: {model_name}" + ) + + # Check meta path compatibility if provided + if passages_file and str(passages_file).endswith(".meta.json"): + meta_matches = _check_server_meta_path(port, str(passages_file)) + if not meta_matches: + print( + f"⚠️ Existing server on port {port} has different meta path. Attempting to update..." + ) + if not _update_server_meta_path(port, str(passages_file)): + raise RuntimeError( + "❌ Failed to update server meta path. This may cause data synchronization issues." + ) + print( + f"✅ Successfully updated server meta path to: {passages_file}" + ) + else: + print( + f"✅ Existing server on port {port} is using correct meta path: {passages_file}" + ) + + print(f"✅ Server on port {port} is compatible and ready to use.") return True - print(f"INFO: Starting session-level embedding server for '{self.backend_module_name}'...") + print( + f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..." 
+ ) try: command = [ sys.executable, - "-m", self.backend_module_name, - "--zmq-port", str(port), - "--model-name", model_name + "-m", + self.backend_module_name, + "--zmq-port", + str(port), + "--model-name", + model_name, ] # Add extra arguments for specific backends @@ -76,9 +320,9 @@ class EmbeddingServerManager: stdout=subprocess.PIPE, stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring text=True, - encoding='utf-8', + encoding="utf-8", bufsize=1, # Line buffered - universal_newlines=True + universal_newlines=True, ) self.server_port = port print(f"INFO: Server process started with PID: {self.server_process.pid}") @@ -86,17 +330,21 @@ class EmbeddingServerManager: max_wait, wait_interval = 120, 0.5 for _ in range(int(max_wait / wait_interval)): if _check_port(port): - print(f"✅ Embedding server is up and ready for this session.") + print("✅ Embedding server is up and ready for this session.") log_thread = threading.Thread(target=self._log_monitor, daemon=True) log_thread.start() return True if self.server_process.poll() is not None: - print("❌ ERROR: Server process terminated unexpectedly during startup.") + print( + "❌ ERROR: Server process terminated unexpectedly during startup." + ) self._print_recent_output() return False time.sleep(wait_interval) - print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.") + print( + f"❌ ERROR: Server process failed to start listening within {max_wait} seconds." + ) self.stop_server() return False @@ -110,8 +358,7 @@ class EmbeddingServerManager: return try: # Read any available output - import select - import sys + if select.select([self.server_process.stdout], [], [], 0)[0]: output = self.server_process.stdout.read() if output: @@ -129,19 +376,25 @@ class EmbeddingServerManager: line = self.server_process.stdout.readline() if not line: break - print(f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True) + print( + f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True + ) except Exception as e: print(f"Log monitor error: {e}") def stop_server(self): """Stops the embedding server process if it's running.""" if self.server_process and self.server_process.poll() is None: - print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...") + print( + f"INFO: Terminating session server process (PID: {self.server_process.pid})..." + ) self.server_process.terminate() try: self.server_process.wait(timeout=5) print("INFO: Server process terminated.") except subprocess.TimeoutExpired: - print("WARNING: Server process did not terminate gracefully, killing it.") + print( + "WARNING: Server process did not terminate gracefully, killing it." + ) self.server_process.kill() self.server_process = None diff --git a/pyproject.toml b/pyproject.toml index 6fcdbdf..3fa4b7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,8 @@ dependencies = [ "llama-index-node-parser-docling", "ipykernel==6.29.5", "msgpack>=1.1.1", + "llama-index-vector-stores-faiss>=0.4.0", + "llama-index-embeddings-huggingface>=0.5.5", ] [project.optional-dependencies]