refactor: check if current emb_server has correct passages/embedder
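
The hunks below add a small msgpack-over-ZMQ control protocol to the embedding servers, so a manager process can ask an already-running server which passages meta file and which embedding model it is serving, and hot-swap either one instead of blindly reusing a stale server. A minimal client-side sketch of that handshake (the sentinel strings and reply shapes are taken from the hunks below; the port, timeout, and helper name are illustrative, not part of the commit):

import msgpack
import zmq

def control_request(port: int, request: list, timeout_ms: int = 3000) -> list:
    """Send one msgpack-encoded control frame and return the decoded reply."""
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    sock.setsockopt(zmq.RCVTIMEO, timeout_ms)
    sock.connect(f"tcp://localhost:{port}")
    try:
        sock.send(msgpack.packb(request))
        return msgpack.unpackb(sock.recv())
    finally:
        sock.close()
        ctx.term()

# Request -> reply shapes used by the servers in this commit:
#   ["__QUERY_META_PATH__"]         -> [current_meta_path]
#   ["__UPDATE_META_PATH__", path]  -> ["SUCCESS"] or ["FAILED", <error>]
#   ["__QUERY_MODEL__"]             -> [model_name]
#   ["__UPDATE_MODEL__", name]      -> ["SUCCESS"] or ["FAILED", <error>]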
@@ -14,7 +14,12 @@ dotenv.load_dotenv()
 # Default WeChat export directory
 DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
 
-def create_leann_index_from_multiple_wechat_exports(export_dirs: List[Path], index_path: str = "wechat_history_index.leann", max_count: int = -1):
+def create_leann_index_from_multiple_wechat_exports(
+    export_dirs: List[Path],
+    index_path: str = "wechat_history_index.leann",
+    max_count: int = -1,
+):
     """
     Create LEANN index from multiple WeChat export data sources.
 
@@ -27,6 +32,7 @@ def create_leann_index_from_multiple_wechat_exports(export_dirs: List[Path], ind
 
     # Load documents using WeChatHistoryReader from history_data
     from history_data.wechat_history import WeChatHistoryReader
 
     reader = WeChatHistoryReader()
 
     INDEX_DIR = Path(index_path).parent
@@ -38,13 +44,15 @@ def create_leann_index_from_multiple_wechat_exports(export_dirs: List[Path], ind
 
     # Process each WeChat export directory
     for i, export_dir in enumerate(export_dirs):
-        print(f"\nProcessing WeChat export {i+1}/{len(export_dirs)}: {export_dir}")
+        print(
+            f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
+        )
 
         try:
             documents = reader.load_data(
                 wechat_export_dir=str(export_dir),
                 max_count=max_count,
-                concatenate_messages=False # Disable concatenation - one message per document
+                concatenate_messages=False,  # Disable concatenation - one message per document
             )
             if documents:
                 print(f"Loaded {len(documents)} chat documents from {export_dir}")
@@ -65,7 +73,9 @@ def create_leann_index_from_multiple_wechat_exports(export_dirs: List[Path], ind
         print("No documents loaded from any source. Exiting.")
         return None
 
-    print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
+    print(
+        f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
+    )
 
     # Create text splitter with 256 chunk size
     text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
@@ -78,7 +88,9 @@ def create_leann_index_from_multiple_wechat_exports(export_dirs: List[Path], ind
     for node in nodes:
         all_texts.append(node.get_content())
 
-    print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
+    print(
+        f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
+    )
 
     # Create LEANN index directory
     print(f"--- Index directory not found, building new index ---")
@@ -96,7 +108,7 @@ def create_leann_index_from_multiple_wechat_exports(export_dirs: List[Path], ind
         complexity=64,
         is_compact=True,
         is_recompute=True,
-        num_threads=1 # Force single-threaded mode
+        num_threads=1,  # Force single-threaded mode
     )
 
     print(f"Adding {len(all_texts)} chat chunks to index...")
@@ -110,7 +122,12 @@ def create_leann_index_from_multiple_wechat_exports(export_dirs: List[Path], ind
 
     return index_path
 
-def create_leann_index(export_dir: str = None, index_path: str = "wechat_history_index.leann", max_count: int = 1000):
+def create_leann_index(
+    export_dir: str = None,
+    index_path: str = "wechat_history_index.leann",
+    max_count: int = 1000,
+):
     """
     Create LEANN index from WeChat chat history data.
 
@@ -132,12 +149,13 @@ def create_leann_index(export_dir: str = None, index_path: str = "wechat_history
 
     # Load documents using WeChatHistoryReader from history_data
     from history_data.wechat_history import WeChatHistoryReader
 
     reader = WeChatHistoryReader()
 
     documents = reader.load_data(
         wechat_export_dir=export_dir,
         max_count=max_count,
-        concatenate_messages=False # Disable concatenation - one message per document
+        concatenate_messages=False,  # Disable concatenation - one message per document
     )
 
     if not documents:
@@ -175,7 +193,7 @@ def create_leann_index(export_dir: str = None, index_path: str = "wechat_history
         complexity=64,
         is_compact=True,
         is_recompute=True,
-        num_threads=1 # Force single-threaded mode
+        num_threads=1,  # Force single-threaded mode
     )
 
     print(f"Adding {len(all_texts)} chat chunks to index...")
@@ -189,6 +207,7 @@ def create_leann_index(export_dir: str = None, index_path: str = "wechat_history
 
     return index_path
 
+
 async def query_leann_index(index_path: str, query: str):
     """
     Query the LEANN index.
@@ -212,28 +231,48 @@ async def query_leann_index(index_path: str, query: str):
             "model": "gpt-4o",
             "api_key": os.getenv("OPENAI_API_KEY"),
         },
-        llm_kwargs={
-            "temperature": 0.0,
-            "max_tokens": 1000
-        }
+        llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
     )
     print(f"Leann: {chat_response}")
 
 
 async def main():
     """Main function with integrated WeChat export functionality."""
 
     # Parse command line arguments
-    parser = argparse.ArgumentParser(description='LEANN WeChat History Reader - Create and query WeChat chat history index')
-    parser.add_argument('--export-dir', type=str, default=DEFAULT_WECHAT_EXPORT_DIR,
-                        help=f'Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})')
-    parser.add_argument('--index-dir', type=str, default="./wechat_history_index_leann_test",
-                        help='Directory to store the LEANN index (default: ./wechat_history_index_leann_test)')
-    parser.add_argument('--max-entries', type=int, default=5000,
-                        help='Maximum number of chat entries to process (default: 5000)')
-    parser.add_argument('--query', type=str, default=None,
-                        help='Single query to run (default: runs example queries)')
-    parser.add_argument('--force-export', action='store_true', default=False,
-                        help='Force re-export of WeChat data even if exports exist')
+    parser = argparse.ArgumentParser(
+        description="LEANN WeChat History Reader - Create and query WeChat chat history index"
+    )
+    parser.add_argument(
+        "--export-dir",
+        type=str,
+        default=DEFAULT_WECHAT_EXPORT_DIR,
+        help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
+    )
+    parser.add_argument(
+        "--index-dir",
+        type=str,
+        default="./wechat_history_index_leann_test",
+        help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
+    )
+    parser.add_argument(
+        "--max-entries",
+        type=int,
+        default=5000,
+        help="Maximum number of chat entries to process (default: 5000)",
+    )
+    parser.add_argument(
+        "--query",
+        type=str,
+        default=None,
+        help="Single query to run (default: runs example queries)",
+    )
+    parser.add_argument(
+        "--force-export",
+        action="store_true",
+        default=False,
+        help="Force re-export of WeChat data even if exports exist",
+    )
 
     args = parser.parse_args()
 
@@ -246,17 +285,19 @@ async def main():
 
     # Initialize WeChat reader with export capabilities
     from history_data.wechat_history import WeChatHistoryReader
 
     reader = WeChatHistoryReader()
 
     # Find existing exports or create new ones using the centralized method
     export_dirs = reader.find_or_export_wechat_data(args.export_dir)
 
     if not export_dirs:
         print("Failed to find or export WeChat data. Exiting.")
         return
 
     # Create or load the LEANN index from all sources
-    index_path = create_leann_index_from_multiple_wechat_exports(export_dirs, INDEX_PATH, max_count=args.max_entries)
+    index_path = create_leann_index_from_multiple_wechat_exports(
+        export_dirs, INDEX_PATH, max_count=args.max_entries
+    )
 
     if index_path:
         if args.query:
@@ -269,8 +310,9 @@ async def main():
             ]
 
             for query in queries:
-                print("\n" + "="*60)
+                print("\n" + "=" * 60)
                 await query_leann_index(index_path, query)
 
 
 if __name__ == "__main__":
     asyncio.run(main())
@@ -14,6 +14,7 @@ import os
 from contextlib import contextmanager
 import zmq
 import numpy as np
+import msgpack
 from pathlib import Path
 
 RED = "\033[91m"
@@ -26,6 +27,7 @@ class SimplePassageLoader:
     """
     def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
         self.passages_data = passages_data or {}
+        self._meta_path = ''
 
     def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
         """Get passage by ID"""
@@ -39,6 +41,9 @@ class SimplePassageLoader:
     def __len__(self) -> int:
         return len(self.passages_data)
 
+    def keys(self):
+        return self.passages_data.keys()
+
 def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
     """
     Load passages using metadata file with PassageManager for lazy loading
@@ -102,7 +107,12 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
         def __len__(self) -> int:
             return len(self.label_map)
 
-    return LazyPassageLoader(passage_manager, label_map)
+        def keys(self):
+            return self.label_map.keys()
+
+    loader = LazyPassageLoader(passage_manager, label_map)
+    loader._meta_path = meta_file
+    return loader
 
 def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
     """
@@ -353,6 +363,100 @@ def create_embedding_server_thread(
                 continue
             print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
 
+            # Handle control messages (MessagePack format)
+            try:
+                request_payload = msgpack.unpackb(message)
+                if isinstance(request_payload, list) and len(request_payload) >= 1:
+                    if request_payload[0] == "__QUERY_META_PATH__":
+                        # Return the current meta path being used by the server
+                        current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else ''
+                        response = [current_meta_path]
+                        socket.send_multipart([identity, b'', msgpack.packb(response)])
+                        continue
+
+                    elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2:
+                        # Update the server's meta path and reload passages
+                        new_meta_path = request_payload[1]
+                        try:
+                            print(f"INFO: Updating server meta path to: {new_meta_path}")
+                            # Reload passages from the new meta file
+                            passages = load_passages_from_metadata(new_meta_path)
+                            # Store the meta path for future queries
+                            passages._meta_path = new_meta_path
+                            response = ["SUCCESS"]
+                            print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages")
+                        except Exception as e:
+                            print(f"ERROR: Failed to update meta path: {e}")
+                            response = ["FAILED", str(e)]
+                        socket.send_multipart([identity, b'', msgpack.packb(response)])
+                        continue
+
+                    elif request_payload[0] == "__QUERY_MODEL__":
+                        # Return the current model being used by the server
+                        response = [model_name]
+                        socket.send_multipart([identity, b'', msgpack.packb(response)])
+                        continue
+
+                    elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2:
+                        # Update the server's embedding model
+                        new_model_name = request_payload[1]
+                        try:
+                            print(f"INFO: Updating server model from {model_name} to: {new_model_name}")
+
+                            # Clean up old model to free memory
+                            if not use_mlx:
+                                print("INFO: Releasing old model from memory...")
+                                old_model = model
+                                old_tokenizer = tokenizer
+
+                                # Load new tokenizer first
+                                print(f"Loading new tokenizer for {new_model_name}...")
+                                tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True)
+
+                                # Load new model
+                                print(f"Loading new model {new_model_name}...")
+                                model = AutoModel.from_pretrained(new_model_name).to(device).eval()
+
+                                # Optimize new model
+                                if cuda_available or mps_available:
+                                    try:
+                                        model = model.half()
+                                        model = torch.compile(model)
+                                        print(f"INFO: Using FP16 precision with model: {new_model_name}")
+                                    except Exception as e:
+                                        print(f"WARNING: Model optimization failed: {e}")
+
+                                # Now safely delete old model after new one is loaded
+                                del old_model
+                                del old_tokenizer
+
+                                # Clear GPU cache if available
+                                if device.type == "cuda":
+                                    torch.cuda.empty_cache()
+                                    print("INFO: Cleared CUDA cache")
+                                elif device.type == "mps":
+                                    torch.mps.empty_cache()
+                                    print("INFO: Cleared MPS cache")
+
+                                # Force garbage collection
+                                import gc
+                                gc.collect()
+                                print("INFO: Memory cleanup completed")
+
+                            # Update model name
+                            model_name = new_model_name
+
+                            response = ["SUCCESS"]
+                            print(f"INFO: Successfully updated model to: {new_model_name}")
+                        except Exception as e:
+                            print(f"ERROR: Failed to update model: {e}")
+                            response = ["FAILED", str(e)]
+                        socket.send_multipart([identity, b'', msgpack.packb(response)])
+                        continue
+            except:
+                # Not a control message, continue with normal protobuf processing
+                pass
+
             e2e_start = time.time()
             lookup_timer = DeviceTimer("text lookup")
 
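Both server loops use the same dispatch trick: every incoming frame is first tried as a msgpack control message, and anything that fails to decode, or lacks a known sentinel, falls through to the normal embedding path (in the hunk above via the bare `except: pass`). A hedged sketch of that dispatch, with a hypothetical helper name:

import msgpack

CONTROL_SENTINELS = {
    "__QUERY_META_PATH__",
    "__UPDATE_META_PATH__",
    "__QUERY_MODEL__",
    "__UPDATE_MODEL__",
}

def is_control_frame(message: bytes) -> bool:
    """True if the frame is a control message rather than embedding traffic."""
    try:
        payload = msgpack.unpackb(message)
    except Exception:
        return False  # not msgpack at all: treat as a normal request
    return bool(
        isinstance(payload, list) and payload and payload[0] in CONTROL_SENTINELS
    )

One consequence of the server's bare `except:` is that an error raised while handling a genuine control message appears to fall through and be re-interpreted as an embedding request; the sketch narrows that by catching only the decode step.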
@@ -17,10 +17,12 @@ import msgpack
 import json
 from pathlib import Path
 from typing import Dict, Any, Optional, Union
+import sys
 
 RED = "\033[91m"
 RESET = "\033[0m"
 
+
 def is_similarity_metric():
     """
     Check if the metric type is similarity-based (like inner product).
@@ -28,21 +30,26 @@ def is_similarity_metric():
     """
     return True  # 1 is METRIC_INNER_PRODUCT in FAISS
 
+
 # Function for E5-style average pooling
 import torch
 from torch import Tensor
 import torch.nn.functional as F
 
+
 def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
 
+
 class SimplePassageLoader:
     """
     Simple passage loader that replaces config.py dependencies
     """
+
     def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
         self.passages_data = passages_data or {}
+        self._meta_path = ""
 
     def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
         """Get passage by ID"""
@@ -56,18 +63,19 @@ class SimplePassageLoader:
     def __len__(self) -> int:
         return len(self.passages_data)
 
+    def keys(self):
+        return self.passages_data.keys()
+
 
 def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
     """
     Load passages using metadata file with PassageManager for lazy loading
     """
     # Load metadata to get passage sources
-    with open(meta_file, 'r') as f:
+    with open(meta_file, "r") as f:
         meta = json.load(f)
 
     # Import PassageManager dynamically to avoid circular imports
-    import sys
-    import importlib.util
 
     # Find the leann package directory relative to this file
     current_dir = Path(__file__).parent
     leann_core_path = current_dir.parent.parent / "leann-core" / "src"
@@ -75,7 +83,8 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
 
     try:
         from leann.api import PassageManager
-        passage_manager = PassageManager(meta['passage_sources'])
+
+        passage_manager = PassageManager(meta["passage_sources"])
     finally:
         sys.path.pop(0)
 
@@ -85,7 +94,8 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
 
     if label_map_file.exists():
         import pickle
-        with open(label_map_file, 'rb') as f:
+
+        with open(label_map_file, "rb") as f:
             label_map = pickle.load(f)
         print(f"Loaded label map with {len(label_map)} entries")
     else:
@@ -122,8 +132,12 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
         def __len__(self) -> int:
             return len(self.label_map)
 
+        def keys(self):
+            return self.label_map.keys()
+
     return LazyPassageLoader(passage_manager, label_map)
 
 
 def create_hnsw_embedding_server(
     passages_file: Optional[str] = None,
     passages_data: Optional[Dict[str, str]] = None,
@@ -158,7 +172,7 @@ def create_hnsw_embedding_server(
     print(f"Tokenizer loaded successfully!")
 
     # Device setup
-    mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
+    mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
     cuda_available = torch.cuda.is_available()
 
     print(f"MPS available: {mps_available}")
@@ -182,9 +196,10 @@ def create_hnsw_embedding_server(
 
     # Check port availability
    import socket
+
    def check_port(port):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            return s.connect_ex(('localhost', port)) == 0
+            return s.connect_ex(("localhost", port)) == 0
 
    if check_port(zmq_port):
        print(f"{RED}Port {zmq_port} is already in use{RESET}")
@@ -196,8 +211,14 @@ def create_hnsw_embedding_server(
             model = torch.compile(model)
             print(f"Using FP16 precision with model: {model_name}")
         elif use_int8:
-            print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
-            from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
+            print(
+                "- Using TorchAO for Int8 dynamic activation and Int8 weight quantization"
+            )
+            from torchao.quantization import (
+                quantize_,
+                Int8DynamicActivationInt8WeightConfig,
+            )
+
             quantize_(model, Int8DynamicActivationInt8WeightConfig())
             model = torch.compile(model)
             model.eval()
@@ -209,8 +230,10 @@ def create_hnsw_embedding_server(
         print(f"Using provided passages data: {len(passages)} passages")
     elif passages_file:
         # Check if it's a metadata file or a single passages file
-        if passages_file.endswith('.meta.json'):
+        if passages_file.endswith(".meta.json"):
             passages = load_passages_from_metadata(passages_file)
+            # Store the meta path for future reference
+            passages._meta_path = passages_file
         else:
             # Try to find metadata file in same directory
             passages_dir = Path(passages_file).parent
@@ -220,8 +243,12 @@ def create_hnsw_embedding_server(
                 passages = load_passages_from_metadata(str(meta_files[0]))
             else:
                 # Fallback to original single file loading (will cause warnings)
-                print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
-                passages = SimplePassageLoader() # Use empty loader to avoid massive warnings
+                print(
+                    "WARNING: No metadata file found, using single file loading (may cause missing passage warnings)"
+                )
+                passages = (
+                    SimplePassageLoader()
+                )  # Use empty loader to avoid massive warnings
     else:
         passages = SimplePassageLoader()
         print("No passages provided, using empty loader")
@@ -238,6 +265,7 @@ def create_hnsw_embedding_server(
 
     class DeviceTimer:
         """Device event-based timer for accurate timing."""
+
         def __init__(self, name="", device=device):
             self.name = name
             self.device = device
@@ -290,10 +318,7 @@ def create_hnsw_embedding_server(
         _is_bge_model = "bge" in model_name.lower()
         batch_size = len(texts_batch)
 
-        # Validate no empty texts
-        for i, text in enumerate(texts_batch):
-            if not text or text.strip() == "":
-                raise RuntimeError(f"FATAL: Empty text at batch index {i}, ID: {ids_batch[i] if i < len(ids_batch) else 'unknown'}")
+        # Allow empty texts to pass through (remove validation)
 
         # E5 model preprocessing
         if _is_e5_model:
@@ -303,9 +328,13 @@ def create_hnsw_embedding_server(
 
         # Set max length
         if _is_e5_model:
-            current_max_length = custom_max_length_param if custom_max_length_param is not None else 512
+            current_max_length = (
+                custom_max_length_param if custom_max_length_param is not None else 512
+            )
         else:
-            current_max_length = custom_max_length_param if custom_max_length_param is not None else 256
+            current_max_length = (
+                custom_max_length_param if custom_max_length_param is not None else 256
+            )
 
         tokenize_timer = DeviceTimer("tokenization (batch)", device)
         to_device_timer = DeviceTimer("transfer to device (batch)", device)
@@ -335,18 +364,35 @@ def create_hnsw_embedding_server(
         with pool_timer.timing():
             if _is_bge_model:
                 pooled_embeddings = out.last_hidden_state[:, 0]
-            elif not hasattr(out, 'last_hidden_state'):
+            elif not hasattr(out, "last_hidden_state"):
                 if isinstance(out, torch.Tensor) and len(out.shape) == 2:
                     pooled_embeddings = out
                 else:
-                    print(f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}")
-                    hidden_dim = getattr(model.config, 'hidden_size', 384 if _is_e5_model else 768)
-                    pooled_embeddings = torch.zeros((batch_size, hidden_dim), device=device, dtype=enc["input_ids"].dtype if hasattr(enc["input_ids"], "dtype") else torch.float32)
+                    print(
+                        f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
+                    )
+                    hidden_dim = getattr(
+                        model.config, "hidden_size", 384 if _is_e5_model else 768
+                    )
+                    pooled_embeddings = torch.zeros(
+                        (batch_size, hidden_dim),
+                        device=device,
+                        dtype=enc["input_ids"].dtype
+                        if hasattr(enc["input_ids"], "dtype")
+                        else torch.float32,
+                    )
             elif _is_e5_model:
-                pooled_embeddings = e5_average_pool(out.last_hidden_state, enc['attention_mask'])
+                pooled_embeddings = e5_average_pool(
+                    out.last_hidden_state, enc["attention_mask"]
+                )
             else:
                 hidden_states = out.last_hidden_state
-                mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
+                mask_expanded = (
+                    enc["attention_mask"]
+                    .unsqueeze(-1)
+                    .expand(hidden_states.size())
+                    .float()
+                )
                 sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
                 sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
                 pooled_embeddings = sum_embeddings / sum_mask
@@ -357,11 +403,17 @@ def create_hnsw_embedding_server(
         final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)
 
         if torch.isnan(final_embeddings).any() or torch.isinf(final_embeddings).any():
-            print(f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! "
-                  f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}")
+            print(
+                f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! "
+                f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}"
+            )
             dim_size = final_embeddings.shape[-1]
-            error_output = torch.zeros((batch_size, dim_size), device='cpu', dtype=torch.float32).numpy()
-            print(f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}")
+            error_output = torch.zeros(
+                (batch_size, dim_size), device="cpu", dtype=torch.float32
+            ).numpy()
+            print(
+                f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}"
+            )
             return error_output
 
         return final_embeddings.cpu().numpy()
@@ -392,14 +444,18 @@ def create_hnsw_embedding_server(
     request_bytes = msgpack.packb(request_payload)
 
     for i in range(3):
-        print(f"Sending warmup request {i+1}/3 via ZMQ (MessagePack)...")
+        print(f"Sending warmup request {i + 1}/3 via ZMQ (MessagePack)...")
         socket.send(request_bytes)
         response_bytes = socket.recv()
 
         response_payload = msgpack.unpackb(response_bytes)
         dimensions = response_payload[0]
-        embeddings_count = dimensions[0] if dimensions and len(dimensions) > 0 else 0
-        print(f"Warmup request {i+1}/3 successful, received {embeddings_count} embeddings")
+        embeddings_count = (
+            dimensions[0] if dimensions and len(dimensions) > 0 else 0
+        )
+        print(
+            f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings"
+        )
         time.sleep(0.1)
 
     print("Client-side MessagePack ZMQ warmup complete")
@@ -410,6 +466,7 @@ def create_hnsw_embedding_server(
 
     def zmq_server_thread():
         """ZMQ server thread"""
+        nonlocal passages, model, tokenizer, model_name
         context = zmq.Context()
         socket = context.socket(zmq.REP)
         socket.bind(f"tcp://*:{zmq_port}")
@@ -428,13 +485,131 @@ def create_hnsw_embedding_server(
 
             try:
                 request_payload = msgpack.unpackb(message_bytes)
+                print(f"DEBUG: Raw request_payload: {request_payload}")
+                print(f"DEBUG: request_payload type: {type(request_payload)}")
+                if isinstance(request_payload, list):
+                    print(f"DEBUG: request_payload length: {len(request_payload)}")
+                    for i, item in enumerate(request_payload):
+                        print(
+                            f"DEBUG: request_payload[{i}]: {type(item)} - {item if len(str(item)) < 100 else str(item)[:100] + '...'}"
+                        )
+
+                # Handle control messages for meta path and model management
+                if isinstance(request_payload, list) and len(request_payload) >= 1:
+                    if request_payload[0] == "__QUERY_META_PATH__":
+                        # Return the current meta path being used by the server
+                        current_meta_path = (
+                            getattr(passages, "_meta_path", "")
+                            if hasattr(passages, "_meta_path")
+                            else ""
+                        )
+                        response = [current_meta_path]
+                        socket.send(msgpack.packb(response))
+                        continue
+
+                    elif (
+                        request_payload[0] == "__UPDATE_META_PATH__"
+                        and len(request_payload) >= 2
+                    ):
+                        # Update the server's meta path and reload passages
+                        new_meta_path = request_payload[1]
+                        try:
+                            print(
+                                f"INFO: Updating server meta path to: {new_meta_path}"
+                            )
+                            # Reload passages from the new meta file
+                            passages = load_passages_from_metadata(new_meta_path)
+                            # Store the meta path for future queries
+                            passages._meta_path = new_meta_path
+                            response = ["SUCCESS"]
+                            print(
+                                f"INFO: Successfully updated meta path and reloaded {len(passages)} passages"
+                            )
+                        except Exception as e:
+                            print(f"ERROR: Failed to update meta path: {e}")
+                            response = ["FAILED", str(e)]
+                        socket.send(msgpack.packb(response))
+                        continue
+
+                    elif request_payload[0] == "__QUERY_MODEL__":
+                        # Return the current model being used by the server
+                        response = [model_name]
+                        socket.send(msgpack.packb(response))
+                        continue
+
+                    elif (
+                        request_payload[0] == "__UPDATE_MODEL__"
+                        and len(request_payload) >= 2
+                    ):
+                        # Update the server's embedding model
+                        new_model_name = request_payload[1]
+                        try:
+                            print(
+                                f"INFO: Updating server model from {model_name} to: {new_model_name}"
+                            )
+
+                            # Clean up old model to free memory
+                            print("INFO: Releasing old model from memory...")
+                            old_model = model
+                            old_tokenizer = tokenizer
+
+                            # Load new tokenizer first
+                            print(f"Loading new tokenizer for {new_model_name}...")
+                            tokenizer = AutoTokenizer.from_pretrained(
+                                new_model_name, use_fast=True
+                            )
+
+                            # Load new model
+                            print(f"Loading new model {new_model_name}...")
+                            model = AutoModel.from_pretrained(new_model_name)
+                            model.to(device)
+                            model.eval()
+
+                            # Now safely delete old model after new one is loaded
+                            del old_model
+                            del old_tokenizer
+
+                            # Clear GPU cache if available
+                            if device.type == "cuda":
+                                torch.cuda.empty_cache()
+                                print("INFO: Cleared CUDA cache")
+                            elif device.type == "mps":
+                                torch.mps.empty_cache()
+                                print("INFO: Cleared MPS cache")
+
+                            # Update model name
+                            model_name = new_model_name
+
+                            # Force garbage collection
+                            import gc
+
+                            gc.collect()
+                            print("INFO: Memory cleanup completed")
+
+                            response = ["SUCCESS"]
+                            print(
+                                f"INFO: Successfully updated model to: {new_model_name}"
+                            )
+                        except Exception as e:
+                            print(f"ERROR: Failed to update model: {e}")
+                            response = ["FAILED", str(e)]
+                        socket.send(msgpack.packb(response))
+                        continue
+
                 # Handle distance calculation requests
-                if isinstance(request_payload, list) and len(request_payload) == 2 and isinstance(request_payload[0], list) and isinstance(request_payload[1], list):
+                if (
+                    isinstance(request_payload, list)
+                    and len(request_payload) == 2
+                    and isinstance(request_payload[0], list)
+                    and isinstance(request_payload[1], list)
+                ):
                     node_ids = request_payload[0]
                     query_vector = np.array(request_payload[1], dtype=np.float32)
 
-                    print(f"Request for distance calculation: {len(node_ids)} nodes, query vector dim: {len(query_vector)}")
+                    print("DEBUG: Distance calculation request received")
+                    print(f" Node IDs: {node_ids}")
+                    print(f" Query vector dim: {len(query_vector)}")
+                    print(f" Passages loaded: {len(passages)}")
 
                     # Get embeddings for node IDs
                     texts = []
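Note the swap order inside the `__UPDATE_MODEL__` handler above: the replacement tokenizer and model are loaded to completion before `old_model` and `old_tokenizer` are deleted, so a failed load leaves the server running on its previous model, at the cost of briefly holding two models in memory. Only after the swap do the handlers empty the CUDA or MPS cache and force a `gc.collect()` pass.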
@@ -442,12 +617,38 @@ def create_hnsw_embedding_server(
                     with lookup_timer.timing():
                         for nid in node_ids:
                             print(f"DEBUG: Looking up passage ID {nid}")
-                            txtinfo = passages[nid]
-                            if txtinfo is None or txtinfo["text"] == "":
-                                raise RuntimeError(f"FATAL: Passage with ID {nid} returned empty text")
-                            txt = txtinfo["text"]
-                            print(f"DEBUG: Found text for ID {nid}, length: {len(txt)}")
-                            texts.append(txt)
+                            try:
+                                txtinfo = passages[nid]
+                                if txtinfo is None:
+                                    print(
+                                        f"ERROR: Passage with ID {nid} returned None"
+                                    )
+                                    print(f"ERROR: txtinfo: {txtinfo}")
+                                    raise RuntimeError(
+                                        f"FATAL: Passage with ID {nid} returned None"
+                                    )
+                                txt = txtinfo[
+                                    "text"
+                                ]  # Allow empty text to pass through
+                                print(
+                                    f"DEBUG: Found text for ID {nid}, length: {len(txt)}"
+                                )
+                                texts.append(txt)
+                            except KeyError:
+                                print(
+                                    f"ERROR: Passage ID {nid} not found in passages dict"
+                                )
+                                print(
+                                    f"ERROR: Available passage IDs: {list(passages.keys())[:10]}..."
+                                )
+                                raise RuntimeError(
+                                    f"FATAL: Passage with ID {nid} not found"
+                                )
+                            except Exception as e:
+                                print(
+                                    f"ERROR: Exception looking up passage ID {nid}: {e}"
+                                )
+                                raise
                     lookup_timer.print_elapsed()
 
                     # Process embeddings in chunks if needed
@@ -460,7 +661,9 @@ def create_hnsw_embedding_server(
                             chunk_texts = texts[i:end_idx]
                             chunk_ids = node_ids[i:end_idx]
 
-                            embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
+                            embeddings_chunk = process_batch(
+                                chunk_texts, chunk_ids, missing_ids
+                            )
                             all_node_embeddings.append(embeddings_chunk)
 
                             if cuda_available:
@@ -470,31 +673,60 @@ def create_hnsw_embedding_server(
 
                         node_embeddings = np.vstack(all_node_embeddings)
                     else:
-                        node_embeddings = process_batch(texts, node_ids, missing_ids)
+                        node_embeddings = process_batch(
+                            texts, node_ids, missing_ids
+                        )
 
                     # Calculate distances
                     query_tensor = torch.tensor(query_vector, device=device).float()
-                    node_embeddings_tensor = torch.tensor(node_embeddings, device=device).float()
+                    node_embeddings_tensor = torch.tensor(
+                        node_embeddings, device=device
+                    ).float()
 
                     calc_timer = DeviceTimer("distance calculation", device)
                     with calc_timer.timing():
                         with torch.no_grad():
                             if distance_metric == "l2":
-                                node_embeddings_np = node_embeddings_tensor.cpu().numpy().astype(np.float32)
-                                query_np = query_tensor.cpu().numpy().astype(np.float32)
-                                distances = np.sum(np.square(node_embeddings_np - query_np.reshape(1, -1)), axis=1)
-                            else: # mips or cosine
-                                node_embeddings_np = node_embeddings_tensor.cpu().numpy()
+                                node_embeddings_np = (
+                                    node_embeddings_tensor.cpu()
+                                    .numpy()
+                                    .astype(np.float32)
+                                )
+                                query_np = (
+                                    query_tensor.cpu().numpy().astype(np.float32)
+                                )
+                                distances = np.sum(
+                                    np.square(
+                                        node_embeddings_np - query_np.reshape(1, -1)
+                                    ),
+                                    axis=1,
+                                )
+                            else:  # mips or cosine
+                                node_embeddings_np = (
+                                    node_embeddings_tensor.cpu().numpy()
+                                )
                                 query_np = query_tensor.cpu().numpy()
                                 distances = -np.dot(node_embeddings_np, query_np)
                     calc_timer.print_elapsed()
 
                     try:
                         response_payload = distances.flatten().tolist()
-                        response_bytes = msgpack.packb([response_payload], use_single_float=True)
-                        print(f"Sending distance response with {len(distances)} distances")
+                        response_bytes = msgpack.packb(
+                            [response_payload], use_single_float=True
+                        )
+                        print(
+                            f"Sending distance response with {len(distances)} distances"
+                        )
                     except Exception as pack_error:
-                        print(f"Error packing MessagePack distance response: {pack_error}")
+                        print(
+                            f"ERROR: Error packing MessagePack distance response: {pack_error}"
+                        )
+                        print(f"ERROR: distances shape: {distances.shape}")
+                        print(f"ERROR: distances dtype: {distances.dtype}")
+                        print(f"ERROR: distances content: {distances}")
+                        print(f"ERROR: node_ids: {node_ids}")
+                        print(f"ERROR: query_vector shape: {query_vector.shape}")
+                        # Still return empty for now but with full error info
                         response_bytes = msgpack.packb([[]])
 
                     socket.send(response_bytes)
@@ -504,12 +736,20 @@ def create_hnsw_embedding_server(
                     elif device.type == "mps":
                         torch.mps.synchronize()
                     e2e_end = time.time()
-                    print(f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds")
+                    print(
+                        f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds"
+                    )
                     continue
 
                 # Standard embedding request
-                if not isinstance(request_payload, list) or len(request_payload) != 1 or not isinstance(request_payload[0], list):
-                    print(f"Error: Invalid MessagePack request format. Expected [[ids...]], got: {type(request_payload)}")
+                if (
+                    not isinstance(request_payload, list)
+                    or len(request_payload) != 1
+                    or not isinstance(request_payload[0], list)
+                ):
+                    print(
+                        f"Error: Invalid MessagePack request format. Expected [[ids...]], got: {type(request_payload)}"
+                    )
                     socket.send(msgpack.packb([[], []]))
                     continue
 
@@ -529,11 +769,15 @@ def create_hnsw_embedding_server(
                     try:
                         txtinfo = passages[nid]
                         if txtinfo is None or txtinfo["text"] == "":
-                            raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
+                            raise RuntimeError(
+                                f"FATAL: Passage with ID {nid} not found - failing fast"
+                            )
                         else:
                             txt = txtinfo["text"]
                     except (KeyError, IndexError):
-                        raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
+                        raise RuntimeError(
+                            f"FATAL: Passage with ID {nid} not found - failing fast"
+                        )
                     texts.append(txt)
                 lookup_timer.print_elapsed()
 
@@ -542,20 +786,28 @@ def create_hnsw_embedding_server(
 
                 # Process in chunks
                 total_size = len(texts)
-                print(f"Total batch size: {total_size}, max_batch_size: {max_batch_size}")
+                print(
+                    f"Total batch size: {total_size}, max_batch_size: {max_batch_size}"
+                )
 
                 all_embeddings = []
 
                 if total_size > max_batch_size:
-                    print(f"Splitting batch of size {total_size} into chunks of {max_batch_size}")
+                    print(
+                        f"Splitting batch of size {total_size} into chunks of {max_batch_size}"
+                    )
                     for i in range(0, total_size, max_batch_size):
                         end_idx = min(i + max_batch_size, total_size)
-                        print(f"Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
+                        print(
+                            f"Processing chunk {i // max_batch_size + 1}/{(total_size + max_batch_size - 1) // max_batch_size}: items {i} to {end_idx - 1}"
+                        )
 
                         chunk_texts = texts[i:end_idx]
                         chunk_ids = node_ids[i:end_idx]
 
-                        embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
+                        embeddings_chunk = process_batch(
+                            chunk_texts, chunk_ids, missing_ids
+                        )
                         all_embeddings.append(embeddings_chunk)
 
                         if cuda_available:
@@ -571,22 +823,30 @@ def create_hnsw_embedding_server(
                 # Serialization and response
                 ser_start = time.time()
 
-                print(f"DEBUG zmq_server_thread: Final 'hidden' array | Shape: {hidden.shape} | Dtype: {hidden.dtype} | Has NaN/Inf: {np.isnan(hidden).any() or np.isinf(hidden).any()}")
+                print(
+                    f"DEBUG zmq_server_thread: Final 'hidden' array | Shape: {hidden.shape} | Dtype: {hidden.dtype} | Has NaN/Inf: {np.isnan(hidden).any() or np.isinf(hidden).any()}"
+                )
                 if np.isnan(hidden).any() or np.isinf(hidden).any():
-                    print(f"{RED}!!! ERROR: NaN or Inf detected in final 'hidden' numpy array BEFORE sending! "
-                          f"Requested IDs (sample): {node_ids[:5]}...{RESET}")
+                    print(
+                        f"{RED}!!! ERROR: NaN or Inf detected in final 'hidden' numpy array BEFORE sending! "
+                        f"Requested IDs (sample): {node_ids[:5]}...{RESET}"
+                    )
                     assert False
 
                 try:
-                    hidden_contiguous_f32 = np.ascontiguousarray(hidden, dtype=np.float32)
+                    hidden_contiguous_f32 = np.ascontiguousarray(
+                        hidden, dtype=np.float32
+                    )
                     response_payload = [
                         list(hidden_contiguous_f32.shape),
-                        hidden_contiguous_f32.flatten().tolist()
+                        hidden_contiguous_f32.flatten().tolist(),
                     ]
-                    response_bytes = msgpack.packb(response_payload, use_single_float=True)
+                    response_bytes = msgpack.packb(
+                        response_payload, use_single_float=True
+                    )
                 except Exception as pack_error:
                     print(f"Error packing MessagePack response: {pack_error}")
                     response_bytes = msgpack.packb([[], []])
 
                 socket.send(response_bytes)
                 ser_end = time.time()
@@ -606,6 +866,7 @@ def create_hnsw_embedding_server(
             except Exception as e:
                 print(f"Error in ZMQ server loop: {e}")
                 import traceback
+
                 traceback.print_exc()
                 try:
                     socket.send(msgpack.packb([[], []]))
@@ -634,16 +895,40 @@ def create_hnsw_embedding_server(
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="HNSW Embedding service")
     parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
-    parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
-    parser.add_argument("--embeddings-file", type=str, help="Pickle file containing pre-computed embeddings")
+    parser.add_argument(
+        "--passages-file",
+        type=str,
+        help="JSON file containing passage ID to text mapping",
+    )
+    parser.add_argument(
+        "--embeddings-file",
+        type=str,
+        help="Pickle file containing pre-computed embeddings",
+    )
     parser.add_argument("--use-fp16", action="store_true", default=False)
     parser.add_argument("--use-int8", action="store_true", default=False)
     parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
-    parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
-    parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
-                        help="Embedding model name")
-    parser.add_argument("--custom-max-length", type=int, default=None, help="Override model's default max sequence length")
-    parser.add_argument("--distance-metric", type=str, default="mips", help="Distance metric to use")
+    parser.add_argument(
+        "--max-batch-size",
+        type=int,
+        default=128,
+        help="Maximum batch size before splitting",
+    )
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default="sentence-transformers/all-mpnet-base-v2",
+        help="Embedding model name",
+    )
+    parser.add_argument(
+        "--custom-max-length",
+        type=int,
+        default=None,
+        help="Override model's default max sequence length",
+    )
+    parser.add_argument(
+        "--distance-metric", type=str, default="mips", help="Distance metric to use"
+    )
 
     args = parser.parse_args()
 
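The final hunk below adds the client-side probes (`_check_server_meta_path`, `_update_server_meta_path`, `_check_server_model`) to the embedding-server manager. The diff is truncated before the call site that ties them together, so the exact wiring is not shown; what follows is a hedged reconstruction of the reuse logic implied by the commit message, with `_ensure_compatible_server` as a hypothetical name:

def _ensure_compatible_server(port: int, meta_path: str, model_name: str) -> bool:
    """Reuse a running embedding server only if it serves the expected data/model."""
    if not _check_port(port):
        return False  # nothing listening; caller starts a fresh server
    if not _check_server_meta_path(port, meta_path) and not _update_server_meta_path(
        port, meta_path
    ):
        return False  # server would not rebind its passages
    if not _check_server_model(port, model_name):
        # an "__UPDATE_MODEL__" round-trip would go here; the corresponding
        # update helper falls outside the truncated hunk
        return False
    return True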
@@ -1,23 +1,169 @@
-import os
 import threading
 import time
 import atexit
 import socket
 import subprocess
 import sys
+import zmq
+import msgpack
 from pathlib import Path
 from typing import Optional
+import select


 def _check_port(port: int) -> bool:
     """Check if a port is in use"""
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        return s.connect_ex(('localhost', port)) == 0
+        return s.connect_ex(("localhost", port)) == 0
+
+
+def _check_server_meta_path(port: int, expected_meta_path: str) -> bool:
+    """
+    Check if the existing server on the port is using the correct meta file.
+    Returns True if the server has the right meta path, False otherwise.
+    """
+    try:
+        context = zmq.Context()
+        socket = context.socket(zmq.REQ)
+        socket.setsockopt(zmq.RCVTIMEO, 3000)  # 3 second timeout
+        socket.connect(f"tcp://localhost:{port}")
+
+        # Send a special control message to query the server's meta path
+        control_request = ["__QUERY_META_PATH__"]
+        request_bytes = msgpack.packb(control_request)
+        socket.send(request_bytes)
+
+        # Wait for response
+        response_bytes = socket.recv()
+        response = msgpack.unpackb(response_bytes)
+
+        socket.close()
+        context.term()
+
+        # Check if the response contains the meta path and if it matches
+        if isinstance(response, list) and len(response) > 0:
+            server_meta_path = response[0]
+            # Normalize paths for comparison
+            expected_path = Path(expected_meta_path).resolve()
+            server_path = Path(server_meta_path).resolve() if server_meta_path else None
+            return server_path == expected_path
+
+        return False
+
+    except Exception as e:
+        print(f"WARNING: Could not query server meta path on port {port}: {e}")
+        return False
+
+
+def _update_server_meta_path(port: int, new_meta_path: str) -> bool:
+    """
+    Send a control message to update the server's meta path.
+    Returns True if successful, False otherwise.
+    """
+    try:
+        context = zmq.Context()
+        socket = context.socket(zmq.REQ)
+        socket.setsockopt(zmq.RCVTIMEO, 5000)  # 5 second timeout
+        socket.connect(f"tcp://localhost:{port}")
+
+        # Send a control message to update the meta path
+        control_request = ["__UPDATE_META_PATH__", new_meta_path]
+        request_bytes = msgpack.packb(control_request)
+        socket.send(request_bytes)
+
+        # Wait for response
+        response_bytes = socket.recv()
+        response = msgpack.unpackb(response_bytes)
+
+        socket.close()
+        context.term()
+
+        # Check if the update was successful
+        if isinstance(response, list) and len(response) > 0:
+            return response[0] == "SUCCESS"
+
+        return False
+
+    except Exception as e:
+        print(f"ERROR: Could not update server meta path on port {port}: {e}")
+        return False
+
+
+def _check_server_model(port: int, expected_model: str) -> bool:
+    """
+    Check if the existing server on the port is using the correct embedding model.
+    Returns True if the server has the right model, False otherwise.
+    """
+    try:
+        context = zmq.Context()
+        socket = context.socket(zmq.REQ)
+        socket.setsockopt(zmq.RCVTIMEO, 3000)  # 3 second timeout
+        socket.connect(f"tcp://localhost:{port}")
+
+        # Send a special control message to query the server's model
+        control_request = ["__QUERY_MODEL__"]
+        request_bytes = msgpack.packb(control_request)
+        socket.send(request_bytes)
+
+        # Wait for response
+        response_bytes = socket.recv()
+        response = msgpack.unpackb(response_bytes)
+
+        socket.close()
+        context.term()
+
+        # Check if the response contains the model name and if it matches
+        if isinstance(response, list) and len(response) > 0:
+            server_model = response[0]
+            return server_model == expected_model
+
+        return False
+
+    except Exception as e:
+        print(f"WARNING: Could not query server model on port {port}: {e}")
+        return False
+
+
+def _update_server_model(port: int, new_model: str) -> bool:
+    """
+    Send a control message to update the server's embedding model.
+    Returns True if successful, False otherwise.
+    """
+    try:
+        context = zmq.Context()
+        socket = context.socket(zmq.REQ)
+        socket.setsockopt(zmq.RCVTIMEO, 30000)  # 30 second timeout for model loading
+        socket.setsockopt(zmq.SNDTIMEO, 5000)  # 5 second timeout for sending
+        socket.connect(f"tcp://localhost:{port}")
+
+        # Send a control message to update the model
+        control_request = ["__UPDATE_MODEL__", new_model]
+        request_bytes = msgpack.packb(control_request)
+        socket.send(request_bytes)
+
+        # Wait for response
+        response_bytes = socket.recv()
+        response = msgpack.unpackb(response_bytes)
+
+        socket.close()
+        context.term()
+
+        # Check if the update was successful
+        if isinstance(response, list) and len(response) > 0:
+            return response[0] == "SUCCESS"
+
+        return False
+
+    except Exception as e:
+        print(f"ERROR: Could not update server model on port {port}: {e}")
+        return False


 class EmbeddingServerManager:
     """
     A generic manager for handling the lifecycle of a backend-specific embedding server process.
     """

     def __init__(self, backend_module_name: str):
         """
         Initializes the manager for a specific backend.
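These four helpers define a small msgpack-over-ZMQ control protocol: a request whose first element is one of __QUERY_META_PATH__, __UPDATE_META_PATH__, __QUERY_MODEL__, or __UPDATE_MODEL__ is treated as a control message rather than an embedding request. The server side of the protocol is not part of this diff; the sketch below is a minimal sketch, assuming a REP-socket loop, of what a compatible handler could look like, with the actual reload logic elided.

# Sketch only (assumed, not from this commit): a REP-socket loop that answers
# the control messages sent by the client helpers above.
from typing import Optional

import msgpack
import zmq


def serve_control_loop(port: int, model_name: str, meta_path: Optional[str]) -> None:
    context = zmq.Context()
    sock = context.socket(zmq.REP)
    sock.bind(f"tcp://*:{port}")
    while True:
        request = msgpack.unpackb(sock.recv())
        command = request[0] if isinstance(request, list) and request else None
        if command == "__QUERY_MODEL__":
            sock.send(msgpack.packb([model_name]))
        elif command == "__QUERY_META_PATH__":
            sock.send(msgpack.packb([meta_path]))
        elif command == "__UPDATE_MODEL__":
            model_name = request[1]  # a real server would reload the model here
            sock.send(msgpack.packb(["SUCCESS"]))
        elif command == "__UPDATE_META_PATH__":
            meta_path = request[1]  # a real server would reload passages here
            sock.send(msgpack.packb(["SUCCESS"]))
        else:
            # Fall through to normal embedding handling (elided in this sketch).
            sock.send(msgpack.packb([]))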
@@ -44,21 +190,119 @@ class EmbeddingServerManager:
             bool: True if the server is started successfully or already running, False otherwise.
         """
         if self.server_process and self.server_process.poll() is None:
-            print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
-            return True
+            # Even if we have a running process, check if model/meta path match
+            if self.server_port is not None:
+                port_in_use = _check_port(self.server_port)
+                if port_in_use:
+                    print(
+                        f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})"
+                    )
+
+                    # Check model compatibility
+                    model_matches = _check_server_model(self.server_port, model_name)
+                    if not model_matches:
+                        print(
+                            f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
+                        )
+                        if not _update_server_model(self.server_port, model_name):
+                            print(
+                                "❌ Failed to update existing server model. Restarting server..."
+                            )
+                            self.stop_server()
+                            # Continue to start new server below
+                        else:
+                            print(
+                                f"✅ Successfully updated existing server model to: {model_name}"
+                            )
+
+                            # Also check meta path if provided
+                            passages_file = kwargs.get("passages_file")
+                            if passages_file and str(passages_file).endswith(
+                                ".meta.json"
+                            ):
+                                meta_matches = _check_server_meta_path(
+                                    self.server_port, str(passages_file)
+                                )
+                                if not meta_matches:
+                                    print(f"⚠️ Updating meta path to: {passages_file}")
+                                    _update_server_meta_path(
+                                        self.server_port, str(passages_file)
+                                    )
+
+                            return True
+                    else:
+                        print(
+                            f"✅ Existing server already using correct model: {model_name}"
+                        )
+                        return True
+                else:
+                    # Server process exists but port not responding - restart
+                    print("⚠️ Server process exists but not responding. Restarting...")
+                    self.stop_server()
+                    # Continue to start new server below
+            else:
+                # No port stored - restart
+                print("⚠️ No port information stored. Restarting server...")
+                self.stop_server()
+                # Continue to start new server below

         if _check_port(port):
-            print(f"WARNING: Port {port} is already in use. Assuming an external server is running.")
+            # Port is in use, check if it's using the correct meta file and model
+            passages_file = kwargs.get("passages_file")
+
+            print(f"INFO: Port {port} is in use. Checking server compatibility...")
+
+            # Check model compatibility first
+            model_matches = _check_server_model(port, model_name)
+            if not model_matches:
+                print(
+                    f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
+                )
+                if not _update_server_model(port, model_name):
+                    raise RuntimeError(
+                        f"❌ Failed to update server model to {model_name}. Consider using a different port."
+                    )
+                print(f"✅ Successfully updated server model to: {model_name}")
+            else:
+                print(
+                    f"✅ Existing server on port {port} is using correct model: {model_name}"
+                )
+
+            # Check meta path compatibility if provided
+            if passages_file and str(passages_file).endswith(".meta.json"):
+                meta_matches = _check_server_meta_path(port, str(passages_file))
+                if not meta_matches:
+                    print(
+                        f"⚠️ Existing server on port {port} has different meta path. Attempting to update..."
+                    )
+                    if not _update_server_meta_path(port, str(passages_file)):
+                        raise RuntimeError(
+                            "❌ Failed to update server meta path. This may cause data synchronization issues."
+                        )
+                    print(
+                        f"✅ Successfully updated server meta path to: {passages_file}"
+                    )
+                else:
+                    print(
+                        f"✅ Existing server on port {port} is using correct meta path: {passages_file}"
+                    )
+
+            print(f"✅ Server on port {port} is compatible and ready to use.")
             return True

-        print(f"INFO: Starting session-level embedding server for '{self.backend_module_name}'...")
+        print(
+            f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..."
+        )
         try:
             command = [
                 sys.executable,
-                "-m", self.backend_module_name,
-                "--zmq-port", str(port),
-                "--model-name", model_name
+                "-m",
+                self.backend_module_name,
+                "--zmq-port",
+                str(port),
+                "--model-name",
+                model_name,
             ]

             # Add extra arguments for specific backends
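The reuse logic above resolves to three outcomes: reuse a compatible server (updating its model or meta path in place where possible), adopt an already-listening external server once it passes the same checks, or stop the stale process and fall through to a fresh launch. A usage sketch, assuming the enclosing method is named start_server(port, model_name, **kwargs); its full signature lies outside this hunk, so treat these names as inferred.

# Hypothetical usage; method name, signature, and module path inferred, not
# confirmed by this diff.
import atexit

manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
ok = manager.start_server(
    port=5555,
    model_name="sentence-transformers/all-mpnet-base-v2",
    passages_file="wechat_history_index.meta.json",  # checked only if it ends in .meta.json
)
if ok:
    atexit.register(manager.stop_server)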
@@ -76,9 +320,9 @@ class EmbeddingServerManager:
                 stdout=subprocess.PIPE,
                 stderr=subprocess.STDOUT,  # Merge stderr into stdout for easier monitoring
                 text=True,
-                encoding='utf-8',
+                encoding="utf-8",
                 bufsize=1,  # Line buffered
-                universal_newlines=True
+                universal_newlines=True,
             )
             self.server_port = port
             print(f"INFO: Server process started with PID: {self.server_process.pid}")
@@ -86,17 +330,21 @@ class EmbeddingServerManager:
             max_wait, wait_interval = 120, 0.5
             for _ in range(int(max_wait / wait_interval)):
                 if _check_port(port):
-                    print(f"✅ Embedding server is up and ready for this session.")
+                    print("✅ Embedding server is up and ready for this session.")
                     log_thread = threading.Thread(target=self._log_monitor, daemon=True)
                     log_thread.start()
                     return True
                 if self.server_process.poll() is not None:
-                    print("❌ ERROR: Server process terminated unexpectedly during startup.")
+                    print(
+                        "❌ ERROR: Server process terminated unexpectedly during startup."
+                    )
                     self._print_recent_output()
                     return False
                 time.sleep(wait_interval)

-            print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
+            print(
+                f"❌ ERROR: Server process failed to start listening within {max_wait} seconds."
+            )
             self.stop_server()
             return False
@@ -110,8 +358,7 @@ class EmbeddingServerManager:
             return
         try:
             # Read any available output
-            import select
-            import sys
+
             if select.select([self.server_process.stdout], [], [], 0)[0]:
                 output = self.server_process.stdout.read()
                 if output:
@@ -129,19 +376,25 @@ class EmbeddingServerManager:
                 line = self.server_process.stdout.readline()
                 if not line:
                     break
-                print(f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True)
+                print(
+                    f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True
+                )
         except Exception as e:
             print(f"Log monitor error: {e}")

     def stop_server(self):
         """Stops the embedding server process if it's running."""
         if self.server_process and self.server_process.poll() is None:
-            print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
+            print(
+                f"INFO: Terminating session server process (PID: {self.server_process.pid})..."
+            )
             self.server_process.terminate()
             try:
                 self.server_process.wait(timeout=5)
                 print("INFO: Server process terminated.")
             except subprocess.TimeoutExpired:
-                print("WARNING: Server process did not terminate gracefully, killing it.")
+                print(
+                    "WARNING: Server process did not terminate gracefully, killing it."
+                )
                 self.server_process.kill()
         self.server_process = None
@@ -32,6 +32,8 @@ dependencies = [
     "llama-index-node-parser-docling",
     "ipykernel==6.29.5",
     "msgpack>=1.1.1",
+    "llama-index-vector-stores-faiss>=0.4.0",
+    "llama-index-embeddings-huggingface>=0.5.5",
 ]

 [project.optional-dependencies]