fix larger file read and add faq
This commit is contained in:
@@ -70,7 +70,7 @@ async def main():
|
|||||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||||
# query = "What is the Off-policy training in RL?"
|
# query = "What is the Off-policy training in RL?"
|
||||||
print(f"You: {query}")
|
print(f"You: {query}")
|
||||||
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True)
|
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,embedding_model="facebook/contriever")
|
||||||
print(f"Leann: {chat_response}")
|
print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -76,8 +76,8 @@ class EmbeddingServerManager:
|
|||||||
self.server_process = subprocess.Popen(
|
self.server_process = subprocess.Popen(
|
||||||
command,
|
command,
|
||||||
cwd=project_root,
|
cwd=project_root,
|
||||||
stdout=subprocess.PIPE,
|
# stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.PIPE,
|
# stderr=subprocess.PIPE,
|
||||||
text=True,
|
text=True,
|
||||||
encoding='utf-8'
|
encoding='utf-8'
|
||||||
)
|
)
|
||||||
@@ -246,7 +246,7 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
|
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
|
||||||
complexity = kwargs.get("complexity", 32)
|
complexity = kwargs.get("complexity", 256)
|
||||||
beam_width = kwargs.get("beam_width", 4)
|
beam_width = kwargs.get("beam_width", 4)
|
||||||
|
|
||||||
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)
|
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)
|
||||||
@@ -259,7 +259,7 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
|||||||
|
|
||||||
if recompute_beighbor_embeddings:
|
if recompute_beighbor_embeddings:
|
||||||
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
|
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
|
||||||
zmq_port = kwargs.get("zmq_port", 5555)
|
zmq_port = kwargs.get("zmq_port", 6666)
|
||||||
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
|
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
|
||||||
|
|
||||||
if not self.embedding_server_manager.start_server(zmq_port, embedding_model):
|
if not self.embedding_server_manager.start_server(zmq_port, embedding_model):
|
||||||
|
|||||||
113
tests/sanity_checks/debug_zmq_issue.py
Normal file
113
tests/sanity_checks/debug_zmq_issue.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Debug script to test ZMQ communication with the exact same setup as main_cli_example.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import sys
|
||||||
|
sys.path.append('packages/leann-backend-diskann')
|
||||||
|
from leann_backend_diskann import embedding_pb2
|
||||||
|
|
||||||
|
def test_zmq_with_same_model():
|
||||||
|
print("=== Testing ZMQ with same model as main_cli_example.py ===")
|
||||||
|
|
||||||
|
# Test the exact same model that main_cli_example.py uses
|
||||||
|
model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||||
|
|
||||||
|
# Start server with the same model
|
||||||
|
import subprocess
|
||||||
|
server_cmd = [
|
||||||
|
sys.executable, "-m",
|
||||||
|
"packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
|
||||||
|
"--zmq-port", "5556", # Use different port to avoid conflicts
|
||||||
|
"--model-name", model_name
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Starting server with command: {' '.join(server_cmd)}")
|
||||||
|
server_process = subprocess.Popen(
|
||||||
|
server_cmd,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
text=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for server to start
|
||||||
|
print("Waiting for server to start...")
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
# Check if server is running
|
||||||
|
if server_process.poll() is not None:
|
||||||
|
stdout, stderr = server_process.communicate()
|
||||||
|
print(f"Server failed to start. stdout: {stdout}")
|
||||||
|
print(f"Server failed to start. stderr: {stderr}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print(f"Server started with PID: {server_process.pid}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test client
|
||||||
|
context = zmq.Context()
|
||||||
|
socket = context.socket(zmq.REQ)
|
||||||
|
socket.connect("tcp://127.0.0.1:5556")
|
||||||
|
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout like C++
|
||||||
|
socket.setsockopt(zmq.SNDTIMEO, 30000)
|
||||||
|
|
||||||
|
# Create request with same format as C++
|
||||||
|
request = embedding_pb2.NodeEmbeddingRequest()
|
||||||
|
request.node_ids.extend([0, 1, 2, 3, 4]) # Test with some node IDs
|
||||||
|
|
||||||
|
print(f"Sending request with {len(request.node_ids)} node IDs...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Send request
|
||||||
|
socket.send(request.SerializeToString())
|
||||||
|
|
||||||
|
# Receive response
|
||||||
|
response_data = socket.recv()
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
print(f"Received response in {end_time - start_time:.3f} seconds")
|
||||||
|
print(f"Response size: {len(response_data)} bytes")
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
response = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
response.ParseFromString(response_data)
|
||||||
|
|
||||||
|
print(f"Response dimensions: {list(response.dimensions)}")
|
||||||
|
print(f"Embeddings data size: {len(response.embeddings_data)} bytes")
|
||||||
|
print(f"Missing IDs: {list(response.missing_ids)}")
|
||||||
|
|
||||||
|
# Calculate expected size
|
||||||
|
if len(response.dimensions) == 2:
|
||||||
|
batch_size = response.dimensions[0]
|
||||||
|
embedding_dim = response.dimensions[1]
|
||||||
|
expected_bytes = batch_size * embedding_dim * 4 # 4 bytes per float
|
||||||
|
print(f"Expected bytes: {expected_bytes}, Actual: {len(response.embeddings_data)}")
|
||||||
|
|
||||||
|
if len(response.embeddings_data) == expected_bytes:
|
||||||
|
print("✅ Response format is correct!")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("❌ Response format mismatch!")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
print("❌ Invalid response dimensions!")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error during ZMQ test: {e}")
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
# Clean up
|
||||||
|
server_process.terminate()
|
||||||
|
server_process.wait()
|
||||||
|
print("Server terminated")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = test_zmq_with_same_model()
|
||||||
|
if success:
|
||||||
|
print("\n✅ ZMQ communication test passed!")
|
||||||
|
else:
|
||||||
|
print("\n❌ ZMQ communication test failed!")
|
||||||
Reference in New Issue
Block a user