diff --git a/examples/main_cli_example.py b/examples/main_cli_example.py index 1cfb4d2..d489970 100644 --- a/examples/main_cli_example.py +++ b/examples/main_cli_example.py @@ -74,8 +74,7 @@ async def main(): query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?" print(f"You: {query}") - - chat_response = chat.ask(query, top_k=20) + chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,embedding_model="facebook/contriever") print(f"Leann: {chat_response}") if __name__ == "__main__": diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py index b0b076b..8af4046 100644 --- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py +++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py @@ -76,8 +76,8 @@ class EmbeddingServerManager: self.server_process = subprocess.Popen( command, cwd=project_root, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + # stdout=subprocess.PIPE, + # stderr=subprocess.PIPE, text=True, encoding='utf-8' ) @@ -242,7 +242,7 @@ class DiskannSearcher(LeannBackendSearcherInterface): raise def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]: - complexity = kwargs.get("complexity", 32) + complexity = kwargs.get("complexity", 256) beam_width = kwargs.get("beam_width", 4) USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False) @@ -255,7 +255,7 @@ class DiskannSearcher(LeannBackendSearcherInterface): if recompute_beighbor_embeddings: print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running") - zmq_port = kwargs.get("zmq_port", 5555) + zmq_port = kwargs.get("zmq_port", 6666) embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2") if not self.embedding_server_manager.start_server(zmq_port, embedding_model): diff --git a/tests/sanity_checks/debug_zmq_issue.py b/tests/sanity_checks/debug_zmq_issue.py new file mode 100644 index 0000000..d2d90c1 --- /dev/null +++ b/tests/sanity_checks/debug_zmq_issue.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +""" +Debug script to test ZMQ communication with the exact same setup as main_cli_example.py +""" + +import zmq +import time +import threading +import sys +sys.path.append('packages/leann-backend-diskann') +from leann_backend_diskann import embedding_pb2 + +def test_zmq_with_same_model(): + print("=== Testing ZMQ with same model as main_cli_example.py ===") + + # Test the exact same model that main_cli_example.py uses + model_name = "sentence-transformers/all-mpnet-base-v2" + + # Start server with the same model + import subprocess + server_cmd = [ + sys.executable, "-m", + "packages.leann-backend-diskann.leann_backend_diskann.embedding_server", + "--zmq-port", "5556", # Use different port to avoid conflicts + "--model-name", model_name + ] + + print(f"Starting server with command: {' '.join(server_cmd)}") + server_process = subprocess.Popen( + server_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + + # Wait for server to start + print("Waiting for server to start...") + time.sleep(10) + + # Check if server is running + if server_process.poll() is not None: + stdout, stderr = server_process.communicate() + print(f"Server failed to start. stdout: {stdout}") + print(f"Server failed to start. stderr: {stderr}") + return False + + print(f"Server started with PID: {server_process.pid}") + + try: + # Test client + context = zmq.Context() + socket = context.socket(zmq.REQ) + socket.connect("tcp://127.0.0.1:5556") + socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout like C++ + socket.setsockopt(zmq.SNDTIMEO, 30000) + + # Create request with same format as C++ + request = embedding_pb2.NodeEmbeddingRequest() + request.node_ids.extend([0, 1, 2, 3, 4]) # Test with some node IDs + + print(f"Sending request with {len(request.node_ids)} node IDs...") + start_time = time.time() + + # Send request + socket.send(request.SerializeToString()) + + # Receive response + response_data = socket.recv() + end_time = time.time() + + print(f"Received response in {end_time - start_time:.3f} seconds") + print(f"Response size: {len(response_data)} bytes") + + # Parse response + response = embedding_pb2.NodeEmbeddingResponse() + response.ParseFromString(response_data) + + print(f"Response dimensions: {list(response.dimensions)}") + print(f"Embeddings data size: {len(response.embeddings_data)} bytes") + print(f"Missing IDs: {list(response.missing_ids)}") + + # Calculate expected size + if len(response.dimensions) == 2: + batch_size = response.dimensions[0] + embedding_dim = response.dimensions[1] + expected_bytes = batch_size * embedding_dim * 4 # 4 bytes per float + print(f"Expected bytes: {expected_bytes}, Actual: {len(response.embeddings_data)}") + + if len(response.embeddings_data) == expected_bytes: + print("✅ Response format is correct!") + return True + else: + print("❌ Response format mismatch!") + return False + else: + print("❌ Invalid response dimensions!") + return False + + except Exception as e: + print(f"❌ Error during ZMQ test: {e}") + return False + finally: + # Clean up + server_process.terminate() + server_process.wait() + print("Server terminated") + +if __name__ == "__main__": + success = test_zmq_with_same_model() + if success: + print("\n✅ ZMQ communication test passed!") + else: + print("\n❌ ZMQ communication test failed!") \ No newline at end of file