Compare commits

...

7 Commits

5 changed files with 282 additions and 60 deletions

View File

@@ -71,6 +71,8 @@ source .venv/bin/activate
 uv pip install leann
 ```
 
+> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups).
+
 <details>
 <summary>
 <strong>🔧 Build from Source (Recommended for development)</strong>

View File

@@ -259,24 +259,80 @@ Every configuration choice involves trade-offs:
 The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
 
-## Deep Dive: Critical Configuration Decisions
+## Low-resource setups
 
-### When to Disable Recomputation
+If you don't have a local GPU or builds/searches are too slow, use one or more of the options below.
 
-LEANN's recomputation feature provides exact distance calculations but can be disabled for extreme QPS requirements:
+### 1) Use OpenAI embeddings (no local compute)
+
+Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
 
 ```bash
---no-recompute # Disable selective recomputation
+export OPENAI_API_KEY=sk-...
+
+# Build with OpenAI embeddings
+leann build my-index \
+  --embedding-mode openai \
+  --embedding-model text-embedding-3-small
+
+# Search with OpenAI embeddings (recompute at query time)
+leann search my-index "your query" \
+  --recompute-embeddings
 ```
 
-**Trade-offs**:
-- **With recomputation** (default): Exact distances, best quality, higher latency, minimal storage (only stores metadata, recomputes embeddings on-demand)
-- **Without recomputation**: Must store full embeddings, significantly higher memory and storage usage (10-100x more), but faster search
-
-**Disable when**:
-- You have abundant storage and memory
-- Need extremely low latency (< 100ms)
-- Running a read-heavy workload where storage cost is acceptable
+### 2) Run remote builds with SkyPilot (cloud GPU)
+
+Offload embedding generation and index building to a GPU VM using SkyPilot. A template is provided at `sky/leann-build.yaml`.
+
+```bash
+# One-time: install and configure SkyPilot
+pip install skypilot
+sky launch -c leann-gpu sky/leann-build.yaml
+
+# Build remotely (template installs uv + leann CLI)
+sky exec leann-gpu -- "leann build my-index --docs ~/leann-data --backend hnsw --complexity 64 --graph-degree 32"
+```
+
+Details: see “Running Builds on SkyPilot (Optional)” below.
+
+### 3) Disable recomputation to trade storage for speed
+
+If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
+
+```bash
+# Build without recomputation (HNSW requires non-compact in this mode)
+leann build my-index --no-recompute --no-compact
+
+# Search without recomputation
+leann search my-index "your query" --no-recompute
+```
+
+Trade-offs: lower query-time latency, but significantly higher storage usage.
+
+## Running Builds on SkyPilot (Optional)
+
+You can offload embedding generation and index building to a cloud GPU VM using SkyPilot, without changing any LEANN code. This is useful when your local machine lacks a GPU or you want faster throughput.
+
+### Quick Start
+
+1) Install SkyPilot by following their docs (`pip install skypilot`), then configure cloud credentials.
+
+2) Use the provided SkyPilot template:
+
+```bash
+sky launch -c leann-gpu sky/leann-build.yaml
+```
+
+3) On the remote, either put your data under the mounted path or adjust `file_mounts` in `sky/leann-build.yaml`. Then run the LEANN build:
+
+```bash
+sky exec leann-gpu -- "leann build my-index --docs ~/leann-data --backend hnsw --complexity 64 --graph-degree 32"
+```
+
+Notes:
+- The template installs `uv` and the `leann` CLI globally on the remote instance.
+- Change the `accelerators` and `cloud` settings in `sky/leann-build.yaml` to match your budget/availability (e.g., `A10G:1`, `A100:1`, or CPU-only if you prefer).
+- You can also build with `diskann` by switching `--backend diskann`.
 
 ## Further Reading
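
For a rough sense of the storage gap the section above describes, here is a back-of-envelope sketch; the corpus size and embedding dimension are illustrative assumptions, and float32 storage is assumed:

```python
# Back-of-envelope storage cost of --no-recompute (storing full embeddings).
# num_passages and dim below are illustrative assumptions, not LEANN defaults.
num_passages = 1_000_000
dim = 768            # e.g., a common sentence-transformer output size
bytes_per_float = 4  # float32

full_embeddings_gib = num_passages * dim * bytes_per_float / 1024**3
print(f"Full embeddings: ~{full_embeddings_gib:.1f} GiB")  # ~2.9 GiB

# With recomputation enabled, only graph metadata is stored and embeddings are
# recomputed on demand, which is where the cited 10-100x storage gap comes from.
```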

View File

@@ -95,6 +95,8 @@ def create_hnsw_embedding_server(
         passage_sources.append(source_copy)
 
     passages = PassageManager(passage_sources)
+    # Use index dimensions from metadata for shaping fallback responses
+    embedding_dim: int = int(meta.get("dimensions", 0))
     logger.info(
         f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
     )
@@ -109,6 +111,9 @@ def create_hnsw_embedding_server(
         socket.setsockopt(zmq.RCVTIMEO, 300000)
         socket.setsockopt(zmq.SNDTIMEO, 300000)
 
+        # Track last request type for safe fallback responses on exceptions
+        last_request_type = "unknown"  # one of: 'text', 'distance', 'embedding', 'unknown'
+        last_request_length = 0
         while True:
             try:
                 message_bytes = socket.recv()
@@ -121,6 +126,8 @@ def create_hnsw_embedding_server(
                 if isinstance(request_payload, list) and len(request_payload) > 0:
                     # Check if this is a direct text request (list of strings)
                     if all(isinstance(item, str) for item in request_payload):
+                        last_request_type = "text"
+                        last_request_length = len(request_payload)
                         logger.info(
                             f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
                         )
@@ -145,43 +152,66 @@ def create_hnsw_embedding_server(
                     ):
                         node_ids = request_payload[0]
                         query_vector = np.array(request_payload[1], dtype=np.float32)
+                        last_request_type = "distance"
+                        last_request_length = len(node_ids)
                         logger.debug("Distance calculation request received")
                         logger.debug(f"  Node IDs: {node_ids}")
                         logger.debug(f"  Query vector dim: {len(query_vector)}")
 
-                        # Get embeddings for node IDs
-                        texts = []
-                        for nid in node_ids:
+                        # Get embeddings for node IDs, tolerate missing IDs
+                        texts: list[str] = []
+                        found_indices: list[int] = []
+                        for idx, nid in enumerate(node_ids):
                             try:
                                 passage_data = passages.get_passage(str(nid))
-                                txt = passage_data["text"]
-                                texts.append(txt)
+                                txt = passage_data.get("text", "")
+                                if isinstance(txt, str) and len(txt) > 0:
+                                    texts.append(txt)
+                                    found_indices.append(idx)
+                                else:
+                                    logger.error(f"Empty text for passage ID {nid}")
                             except KeyError:
                                 logger.error(f"Passage ID {nid} not found")
-                                raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
                             except Exception as e:
                                 logger.error(f"Exception looking up passage ID {nid}: {e}")
-                                raise
 
-                        # Process embeddings
-                        embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
-                        logger.info(
-                            f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
-                        )
+                        # Prepare full-length response distances with safe fallbacks
+                        large_distance = 1e9
+                        response_distances = [large_distance] * len(node_ids)
 
-                        # Calculate distances
-                        if distance_metric == "l2":
-                            distances = np.sum(
-                                np.square(embeddings - query_vector.reshape(1, -1)), axis=1
-                            )
-                        else:  # mips or cosine
-                            distances = -np.dot(embeddings, query_vector)
-
-                        response_payload = distances.flatten().tolist()
-                        response_bytes = msgpack.packb([response_payload], use_single_float=True)
-                        logger.debug(f"Sending distance response with {len(distances)} distances")
+                        if texts:
+                            try:
+                                # Process embeddings only for found indices
+                                embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
+                                logger.info(
+                                    f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
+                                )
+                                # Calculate distances for found embeddings only
+                                if distance_metric == "l2":
+                                    partial_distances = np.sum(
+                                        np.square(embeddings - query_vector.reshape(1, -1)), axis=1
+                                    )
+                                else:  # mips or cosine
+                                    partial_distances = -np.dot(embeddings, query_vector)
+                                # Place computed distances back into the full response array
+                                for pos, dval in zip(
+                                    found_indices, partial_distances.flatten().tolist()
+                                ):
+                                    response_distances[pos] = float(dval)
+                            except Exception as e:
+                                logger.error(
+                                    f"Distance computation error, falling back to large distances: {e}"
+                                )
+
+                        # Always reply with exactly len(node_ids) distances
+                        response_bytes = msgpack.packb([response_distances], use_single_float=True)
+                        logger.debug(
+                            f"Sending distance response with {len(response_distances)} distances (found={len(found_indices)})"
+                        )
                         socket.send(response_bytes)
                         e2e_end = time.time()
                         logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
@@ -201,40 +231,61 @@ def create_hnsw_embedding_server(
                         node_ids = request_payload[0]
                         logger.debug(f"Request for {len(node_ids)} node embeddings")
+                        last_request_type = "embedding"
+                        last_request_length = len(node_ids)
 
-                        # Look up texts by node IDs
-                        texts = []
-                        for nid in node_ids:
+                        # Allocate output buffer (B, D) and fill with zeros for robustness
+                        if embedding_dim <= 0:
+                            logger.error("Embedding dimension unknown; cannot serve embedding request")
+                            dims = [0, 0]
+                            data = []
+                        else:
+                            dims = [len(node_ids), embedding_dim]
+                            data = [0.0] * (dims[0] * dims[1])
+
+                        # Look up texts by node IDs; compute embeddings where available
+                        texts: list[str] = []
+                        found_indices: list[int] = []
+                        for idx, nid in enumerate(node_ids):
                             try:
                                 passage_data = passages.get_passage(str(nid))
-                                txt = passage_data["text"]
-                                if not txt:
-                                    raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
-                                texts.append(txt)
+                                txt = passage_data.get("text", "")
+                                if isinstance(txt, str) and len(txt) > 0:
+                                    texts.append(txt)
+                                    found_indices.append(idx)
+                                else:
+                                    logger.error(f"Empty text for passage ID {nid}")
                             except KeyError:
-                                raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
+                                logger.error(f"Passage with ID {nid} not found")
                             except Exception as e:
                                 logger.error(f"Exception looking up passage ID {nid}: {e}")
-                                raise
 
-                        # Process embeddings
-                        embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
-                        logger.info(
-                            f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
-                        )
-
-                        # Serialization and response
-                        if np.isnan(embeddings).any() or np.isinf(embeddings).any():
-                            logger.error(
-                                f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
-                            )
-                            raise AssertionError()
-
-                        hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
-                        response_payload = [
-                            list(hidden_contiguous_f32.shape),
-                            hidden_contiguous_f32.flatten().tolist(),
-                        ]
+                        if texts:
+                            try:
+                                # Process embeddings for found texts only
+                                embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
+                                logger.info(
+                                    f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
+                                )
+                                if np.isnan(embeddings).any() or np.isinf(embeddings).any():
+                                    logger.error(
+                                        f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
+                                    )
+                                    dims = [0, embedding_dim]
+                                    data = []
+                                else:
+                                    # Copy computed embeddings into the correct positions
+                                    emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
+                                    flat = emb_f32.flatten().tolist()
+                                    for j, pos in enumerate(found_indices):
+                                        start = pos * embedding_dim
+                                        end = start + embedding_dim
+                                        data[start:end] = flat[j * embedding_dim : (j + 1) * embedding_dim]
+                            except Exception as e:
+                                logger.error(f"Embedding computation error, returning zeros: {e}")
+
+                        response_payload = [dims, data]
                         response_bytes = msgpack.packb(response_payload, use_single_float=True)
                         socket.send(response_bytes)
@@ -249,7 +300,22 @@ def create_hnsw_embedding_server(
                 import traceback
 
                 traceback.print_exc()
-                socket.send(msgpack.packb([[], []]))
+                # Fallback to a safe, minimal-structure response to avoid client crashes
+                if last_request_type == "distance":
+                    # Return a vector of large distances with the expected length
+                    fallback_len = max(0, int(last_request_length))
+                    large_distance = 1e9
+                    safe_response = [[large_distance] * fallback_len]
+                elif last_request_type == "embedding":
+                    # Return an empty embedding block with known dimension if available
+                    if embedding_dim > 0:
+                        safe_response = [[0, embedding_dim], []]
+                    else:
+                        safe_response = [[0, 0], []]
+                else:
+                    # Unknown request type: default to empty embedding structure
+                    safe_response = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
+                socket.send(msgpack.packb(safe_response, use_single_float=True))
 
     zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
     zmq_thread.start()
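
To make the fallback shapes above concrete, here is a minimal client-side sketch of the msgpack-over-ZMQ exchange. The request payload shapes are inferred from this diff, and the endpoint, port, query dimension, and node IDs are hypothetical; treat it as an illustration, not the project's client API:

```python
# Minimal sketch of a client for the embedding server above (assumptions noted).
import msgpack
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)
sock.connect("tcp://127.0.0.1:5555")  # hypothetical endpoint/port

# Embedding request: [node_ids] -> reply is [dims, flat_data]
sock.send(msgpack.packb([["42", "43"]]))  # hypothetical node IDs
dims, flat = msgpack.unpackb(sock.recv())
if dims[0] == 0:
    # Fallback shape ([0, D] or [0, 0]): the server could not produce embeddings
    print("server returned an empty embedding block")
else:
    rows, cols = dims
    print(f"received {rows} embeddings of dimension {cols}")

# Distance request: [node_ids, query_vector] -> reply is [[d0, d1, ...]]
# The query vector's length must match the index's embedding dimension.
sock.send(msgpack.packb([["42", "43"], [0.0] * 768]))
(distances,) = msgpack.unpackb(sock.recv())
# Missing passages come back as a large sentinel distance (1e9) rather than
# crashing the search, per the fallback logic in the diff above.
print(distances)
```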

View File

@@ -117,7 +117,19 @@ Examples:
build_parser.add_argument("--complexity", type=int, default=64) build_parser.add_argument("--complexity", type=int, default=64)
build_parser.add_argument("--num-threads", type=int, default=1) build_parser.add_argument("--num-threads", type=int, default=1)
build_parser.add_argument("--compact", action="store_true", default=True) build_parser.add_argument("--compact", action="store_true", default=True)
build_parser.add_argument(
"--no-compact",
dest="compact",
action="store_false",
help="Disable compact index storage (store full embeddings; higher storage)",
)
build_parser.add_argument("--recompute", action="store_true", default=True) build_parser.add_argument("--recompute", action="store_true", default=True)
build_parser.add_argument(
"--no-recompute",
dest="recompute",
action="store_false",
help="Disable embedding recomputation (store full embeddings; lower query latency)",
)
build_parser.add_argument( build_parser.add_argument(
"--file-types", "--file-types",
type=str, type=str,
@@ -138,6 +150,18 @@ Examples:
         default=True,
         help="Recompute embeddings (default: True)",
     )
+    search_parser.add_argument(
+        "--no-recompute-embeddings",
+        dest="recompute_embeddings",
+        action="store_false",
+        help="Disable embedding recomputation during search",
+    )
+    search_parser.add_argument(
+        "--no-recompute",
+        dest="recompute_embeddings",
+        action="store_false",
+        help="Alias for --no-recompute-embeddings",
+    )
     search_parser.add_argument(
         "--pruning-strategy",
         choices=["global", "local", "proportional"],
@@ -166,6 +190,18 @@ Examples:
         default=True,
         help="Recompute embeddings (default: True)",
     )
+    ask_parser.add_argument(
+        "--no-recompute-embeddings",
+        dest="recompute_embeddings",
+        action="store_false",
+        help="Disable embedding recomputation during ask",
+    )
+    ask_parser.add_argument(
+        "--no-recompute",
+        dest="recompute_embeddings",
+        action="store_false",
+        help="Alias for --no-recompute-embeddings",
+    )
     ask_parser.add_argument(
         "--pruning-strategy",
         choices=["global", "local", "proportional"],
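
The added options follow the standard argparse pattern of pairing a default-on flag with a `store_false` companion that writes to the same `dest`. A minimal standalone sketch of the resulting behavior:

```python
# Standalone sketch of the paired-flag pattern used in the diff above:
# --recompute is on by default, and --no-recompute flips the same dest to False.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--recompute", action="store_true", default=True)
parser.add_argument("--no-recompute", dest="recompute", action="store_false")

print(parser.parse_args([]).recompute)                  # True (default)
print(parser.parse_args(["--no-recompute"]).recompute)  # False
```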

sky/leann-build.yaml (new file)
View File

@@ -0,0 +1,62 @@
+name: leann-build
+
+resources:
+  # Choose a GPU for fast embeddings (examples: L4, A10G, A100). CPU also works but is slower.
+  accelerators: L4:1
+  # Optionally pin a cloud; otherwise SkyPilot will auto-select
+  # cloud: aws
+  disk_size: 100
+
+envs:
+  # Build parameters (override with: sky launch -c leann-gpu sky/leann-build.yaml --env key=value)
+  index_name: my-index
+  docs: ./data
+  backend: hnsw  # hnsw | diskann
+  complexity: 64
+  graph_degree: 32
+  num_threads: 8
+  # Embedding selection
+  embedding_mode: sentence-transformers  # sentence-transformers | openai | mlx | ollama
+  embedding_model: facebook/contriever
+  # Storage/latency knobs (quoted so the run script sees literal "true"/"false")
+  recompute: "true"  # "true" => selective recomputation; "false" => store full embeddings
+  compact: "true"    # for HNSW only: set to "false" when recompute is "false"
+  # Optional pass-through
+  extra_args: ""
+
+# Sync local paths to the remote VM. Adjust as needed.
+file_mounts:
+  # Example: mount your local data directory used for building
+  ~/leann-data: ${docs}
+
+setup: |
+  set -e
+  # Install uv (package manager)
+  curl -LsSf https://astral.sh/uv/install.sh | sh
+  export PATH="$HOME/.local/bin:$PATH"
+  # Install the LEANN CLI globally on the remote machine
+  uv tool install leann
+
+run: |
+  export PATH="$HOME/.local/bin:$PATH"
+  # Derive flags from env
+  recompute_flag=""
+  if [ "${recompute}" = "false" ] || [ "${recompute}" = "0" ]; then
+    recompute_flag="--no-recompute"
+  fi
+  compact_flag=""
+  if [ "${compact}" = "false" ] || [ "${compact}" = "0" ]; then
+    compact_flag="--no-compact"
+  fi
+  # Build command
+  leann build ${index_name} \
+    --docs ~/leann-data \
+    --backend ${backend} \
+    --complexity ${complexity} \
+    --graph-degree ${graph_degree} \
+    --num-threads ${num_threads} \
+    --embedding-mode ${embedding_mode} \
+    --embedding-model ${embedding_model} \
+    ${recompute_flag} ${compact_flag} ${extra_args}