diff --git a/.github/workflows/build-reusable.yml b/.github/workflows/build-reusable.yml index 673c074..8519c8d 100644 --- a/.github/workflows/build-reusable.yml +++ b/.github/workflows/build-reusable.yml @@ -263,25 +263,25 @@ jobs: # Activate virtual environment source .venv/bin/activate || source .venv/Scripts/activate - # Run all tests with timeout to debug Python 3.13 hanging issue - # Using timeout with INT signal to get full traceback when hanging (Linux only) - if [[ "${{ matrix.python }}" == "3.13" ]] && [[ "$RUNNER_OS" == "Linux" ]]; then - echo "Running tests with timeout for Python 3.13 debugging (Linux)..." - timeout --signal=INT 120 pytest tests/ --full-trace -v || { + # Run all tests with timeout on Linux to prevent hanging + if [[ "$RUNNER_OS" == "Linux" ]]; then + echo "Running tests with timeout (Linux)..." + timeout --signal=INT 180 pytest tests/ -v || { EXIT_CODE=$? if [ $EXIT_CODE -eq 124 ]; then - echo "⚠️ Tests timed out after 120 seconds - likely hanging during collection" - echo "This is a known issue with Python 3.13 - see traceback above for details" + echo "⚠️ Tests timed out after 180 seconds - likely process cleanup issue" + echo "Check for lingering ZMQ connections or child processes" + # Try to clean up any leftover processes + pkill -TERM -P $$ || true + sleep 1 + pkill -KILL -P $$ || true fi exit $EXIT_CODE } - elif [[ "${{ matrix.python }}" == "3.13" ]]; then - # For macOS/Windows, run with verbose output but no timeout - echo "Running tests for Python 3.13 (no timeout on $RUNNER_OS)..." - pytest tests/ --full-trace -v else - # Normal test run for other Python versions - pytest tests/ + # For macOS/Windows, run without GNU timeout + echo "Running tests ($RUNNER_OS)..." + pytest tests/ -v fi - name: Run sanity checks (optional) diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py index b566ae6..e345568 100644 --- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py +++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py @@ -100,6 +100,7 @@ def create_diskann_embedding_server( socket = context.socket( zmq.REP ) # REP socket for both BaseSearcher and DiskANN C++ REQ clients + socket.setsockopt(zmq.LINGER, 0) # Don't block on close socket.bind(f"tcp://*:{zmq_port}") logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}") diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py index bf36883..df0e44b 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py @@ -92,6 +92,7 @@ def create_hnsw_embedding_server( """ZMQ server thread""" context = zmq.Context() socket = context.socket(zmq.REP) + socket.setsockopt(zmq.LINGER, 0) # Don't block on close socket.bind(f"tcp://*:{zmq_port}") logger.info(f"HNSW ZMQ server listening on port {zmq_port}") diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index daae62d..15c865e 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -87,21 +87,23 @@ def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) # Connect to embedding server context = zmq.Context() socket = context.socket(zmq.REQ) + socket.setsockopt(zmq.LINGER, 0) # Don't block on close socket.connect(f"tcp://localhost:{port}") - # Send chunks to server for embedding computation - request = chunks - socket.send(msgpack.packb(request)) + try: + # Send chunks to server for embedding computation + request = chunks + socket.send(msgpack.packb(request)) - # Receive embeddings from server - response = socket.recv() - embeddings_list = msgpack.unpackb(response) + # Receive embeddings from server + response = socket.recv() + embeddings_list = msgpack.unpackb(response) - # Convert back to numpy array - embeddings = np.array(embeddings_list, dtype=np.float32) - - socket.close() - context.term() + # Convert back to numpy array + embeddings = np.array(embeddings_list, dtype=np.float32) + finally: + socket.close(linger=0) + context.term() return embeddings diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py index ff368c8..5b9b80a 100644 --- a/packages/leann-core/src/leann/searcher_base.py +++ b/packages/leann-core/src/leann/searcher_base.py @@ -132,10 +132,13 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): import msgpack import zmq + context = None + socket = None try: context = zmq.Context() socket = context.socket(zmq.REQ) socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout + socket.setsockopt(zmq.LINGER, 0) # Don't block on close socket.connect(f"tcp://localhost:{zmq_port}") # Send embedding request @@ -147,9 +150,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): response_bytes = socket.recv() response = msgpack.unpackb(response_bytes) - socket.close() - context.term() - # Convert response to numpy array if isinstance(response, list) and len(response) > 0: return np.array(response, dtype=np.float32) @@ -158,6 +158,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): except Exception as e: raise RuntimeError(f"Failed to compute embeddings via server: {e}") + finally: + if socket: + socket.close(linger=0) + if context: + context.term() @abstractmethod def search( diff --git a/pyproject.toml b/pyproject.toml index 075778f..d03feba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,8 +60,9 @@ dev = [ test = [ "pytest>=8.3.0", # Minimum version for Python 3.13 support -"pytest-timeout>=2.3", -"anyio>=4.0", # For async test support (includes pytest plugin) + "pytest-timeout>=2.3", + "anyio>=4.0", # For async test support (includes pytest plugin) + "psutil>=5.9.0", # For process cleanup in tests "llama-index-core>=0.12.0", "llama-index-readers-file>=0.4.0", "python-dotenv>=1.0.0", diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a1e45a5 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,99 @@ +"""Global test configuration and cleanup fixtures.""" + +import os +import signal +import time +from collections.abc import Generator + +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def global_test_cleanup() -> Generator: + """Global cleanup fixture that runs after all tests. + + This ensures all ZMQ connections and child processes are properly cleaned up, + preventing the test runner from hanging on exit. + """ + yield + + # Cleanup after all tests + try: + import zmq + + # Set a very short linger on any remaining contexts + # This prevents blocking on context termination + ctx = zmq.Context.instance() + ctx.linger = 0 + except Exception: + pass + + # Kill any leftover child processes + try: + import psutil + + current_process = psutil.Process() + children = current_process.children(recursive=True) + + if children: + print(f"\n⚠️ Cleaning up {len(children)} leftover child processes...") + + # First try to terminate gracefully + for child in children: + try: + child.terminate() + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + + # Wait a bit for processes to terminate + gone, alive = psutil.wait_procs(children, timeout=2) + + # Force kill any remaining processes + for child in alive: + try: + print(f" Force killing process {child.pid} ({child.name()})") + child.kill() + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except ImportError: + # psutil not installed, try basic process cleanup + try: + # Send SIGTERM to all child processes + os.killpg(os.getpgid(os.getpid()), signal.SIGTERM) + except Exception: + pass + except Exception as e: + print(f"Warning: Error during process cleanup: {e}") + + # List any remaining threads (for debugging) + try: + import threading + + threads = [t for t in threading.enumerate() if t is not threading.main_thread()] + if threads: + print(f"\n⚠️ {len(threads)} non-main threads still running:") + for t in threads: + print(f" - {t.name} (daemon={t.daemon})") + except Exception: + pass + + +@pytest.fixture(autouse=True) +def cleanup_after_each_test(): + """Cleanup after each test to prevent resource leaks.""" + yield + + # Force garbage collection to trigger any __del__ methods + import gc + + gc.collect() + + # Give a moment for async cleanup + time.sleep(0.1) + + +def pytest_configure(config): + """Configure pytest with better timeout handling.""" + # Set default timeout method to thread if not specified + if not config.getoption("--timeout-method", None): + config.option.timeout_method = "thread"