fix ruff errors and formatting

yichuan520030910320
2025-07-27 02:22:54 -07:00
parent 383c6d8d7e
commit af1790395a
35 changed files with 166 additions and 107 deletions


@@ -8,4 +8,4 @@ on:
jobs:
  build:
    uses: ./.github/workflows/build-reusable.yml


@@ -17,23 +17,23 @@ jobs:
      - uses: actions/checkout@v4
        with:
          ref: ${{ inputs.ref }}
      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Install uv
        uses: astral-sh/setup-uv@v4
      - name: Install ruff
        run: |
          uv tool install ruff
      - name: Run ruff check
        run: |
          ruff check .
      - name: Run ruff format check
        run: |
          ruff format --check .
@@ -65,40 +65,40 @@ jobs:
        - os: macos-latest
          python: '3.13'
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v4
        with:
          ref: ${{ inputs.ref }}
          submodules: recursive
      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
      - name: Install uv
        uses: astral-sh/setup-uv@v4
      - name: Install system dependencies (Ubuntu)
        if: runner.os == 'Linux'
        run: |
          sudo apt-get update
          sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
            pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
          # Install Intel MKL for DiskANN
          wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
          sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
          source /opt/intel/oneapi/setvars.sh
          echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
          echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
      - name: Install system dependencies (macOS)
        if: runner.os == 'macOS'
        run: |
          brew install llvm libomp boost protobuf zeromq
      - name: Install build dependencies
        run: |
          uv pip install --system scikit-build-core numpy swig Cython pybind11
@@ -107,7 +107,7 @@ jobs:
          else
            uv pip install --system delocate
          fi
      - name: Build packages
        run: |
          # Build core (platform independent)
@@ -116,7 +116,7 @@ jobs:
            uv build
            cd ../..
          fi
          # Build HNSW backend
          cd packages/leann-backend-hnsw
          if [ "${{ matrix.os }}" == "macos-latest" ]; then
@@ -125,7 +125,7 @@ jobs:
            uv build --wheel --python python
          fi
          cd ../..
          # Build DiskANN backend
          cd packages/leann-backend-diskann
          if [ "${{ matrix.os }}" == "macos-latest" ]; then
@@ -134,14 +134,14 @@ jobs:
            uv build --wheel --python python
          fi
          cd ../..
          # Build meta package (platform independent)
          if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
            cd packages/leann
            uv build
            cd ../..
          fi
      - name: Repair wheels (Linux)
        if: runner.os == 'Linux'
        run: |
@@ -153,7 +153,7 @@ jobs:
            mv dist_repaired dist
          fi
          cd ../..
          # Repair DiskANN wheel
          cd packages/leann-backend-diskann
          if [ -d dist ]; then
@@ -162,7 +162,7 @@ jobs:
            mv dist_repaired dist
          fi
          cd ../..
      - name: Repair wheels (macOS)
        if: runner.os == 'macOS'
        run: |
@@ -174,7 +174,7 @@ jobs:
            mv dist_repaired dist
          fi
          cd ../..
          # Repair DiskANN wheel
          cd packages/leann-backend-diskann
          if [ -d dist ]; then
@@ -183,14 +183,14 @@ jobs:
            mv dist_repaired dist
          fi
          cd ../..
      - name: List built packages
        run: |
          echo "📦 Built packages:"
          find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: packages-${{ matrix.os }}-py${{ matrix.python }}
          path: packages/*/dist/


@@ -16,10 +16,10 @@ jobs:
      contents: write
    outputs:
      commit-sha: ${{ steps.push.outputs.commit-sha }}
    steps:
      - uses: actions/checkout@v4
      - name: Validate version
        run: |
          # Remove 'v' prefix if present for validation
@@ -30,7 +30,7 @@ jobs:
            exit 1
          fi
          echo "✅ Version format valid: ${{ inputs.version }}"
      - name: Update versions and push
        id: push
        run: |
@@ -38,7 +38,7 @@ jobs:
          CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
          echo "Current version: $CURRENT_VERSION"
          echo "Target version: ${{ inputs.version }}"
          if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
            echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
            COMMIT_SHA=$(git rev-parse HEAD)
@@ -52,7 +52,7 @@ jobs:
            COMMIT_SHA=$(git rev-parse HEAD)
            echo "✅ Pushed version update: $COMMIT_SHA"
          fi
          echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
  build-packages:
@@ -60,7 +60,7 @@ jobs:
    needs: update-version
    uses: ./.github/workflows/build-reusable.yml
    with:
      ref: 'main'
  publish:
    name: Publish and Release
@@ -69,26 +69,26 @@ jobs:
    runs-on: ubuntu-latest
    permissions:
      contents: write
    steps:
      - uses: actions/checkout@v4
        with:
          ref: 'main'
      - name: Download all artifacts
        uses: actions/download-artifact@v4
        with:
          path: dist-artifacts
      - name: Collect packages
        run: |
          mkdir -p dist
          find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
          find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
          echo "📦 Packages to publish:"
          ls -la dist/
      - name: Publish to PyPI
        env:
          TWINE_USERNAME: __token__
@@ -98,12 +98,12 @@ jobs:
echo "❌ PYPI_API_TOKEN not configured!" echo "❌ PYPI_API_TOKEN not configured!"
exit 1 exit 1
fi fi
pip install twine pip install twine
twine upload dist/* --skip-existing --verbose twine upload dist/* --skip-existing --verbose
echo "✅ Published to PyPI!" echo "✅ Published to PyPI!"
- name: Create release - name: Create release
run: | run: |
# Check if tag already exists # Check if tag already exists
@@ -114,7 +114,7 @@ jobs:
            git push origin "v${{ inputs.version }}"
            echo "✅ Created and pushed tag v${{ inputs.version }}"
          fi
          # Check if release already exists
          if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
            echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
@@ -126,4 +126,4 @@ jobs:
echo "✅ Created GitHub release v${{ inputs.version }}" echo "✅ Created GitHub release v${{ inputs.version }}"
fi fi
env: env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.gitignore vendored

@@ -9,7 +9,7 @@ demo/indices/
outputs/
*.pkl
*.pdf
*.idx
*.map
.history/
lm_eval.egg-info/
@@ -85,4 +85,4 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
*.meta.json
*.passages.json
batchtest.py


@@ -19,4 +19,4 @@ That's it! The workflow will automatically:
- ✅ Publish to PyPI
- ✅ Create GitHub tag and release
Check progress: https://github.com/yichuan-w/LEANN/actions


@@ -8,4 +8,4 @@ We welcome contributions! Leann is built by the community, for the community.
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results


@@ -7,4 +7,4 @@ You can speed up the process by using a lightweight embedding model. Add this to
```bash
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)


@@ -19,4 +19,4 @@
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment


@@ -18,4 +18,4 @@
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
- [ ] Query rewriting, reranking and expansion


@@ -1,5 +1,5 @@
The Project Gutenberg eBook of Pride and Prejudice
This ebook is for the use of anyone anywhere in the United States and
most other parts of the world at no cost and with almost no restrictions
whatsoever. You may copy it, give it away or re-use it under the terms
@@ -14557,7 +14557,7 @@ her into Derbyshire, had been the means of uniting them.
*** END OF THE PROJECT GUTENBERG EBOOK PRIDE AND PREJUDICE ***
Updated editions will replace the previous one—the old editions will
be renamed.
@@ -14662,7 +14662,7 @@ performed, viewed, copied or distributed:
at www.gutenberg.org. If you
are not located in the United States, you will have to check the laws
of the country where you are located before using this eBook.
1.E.2. If an individual Project Gutenberg™ electronic work is
derived from texts not protected by U.S. copyright law (does not
contain a notice indicating that it is posted with permission of the
@@ -14724,7 +14724,7 @@ provided that:
Gutenberg Literary Archive Foundation at the address specified in
Section 4, “Information about donations to the Project Gutenberg
Literary Archive Foundation.”
• You provide a full refund of any money paid by a user who notifies
you in writing (or by e-mail) within 30 days of receipt that s/he
does not agree to the terms of the full Project Gutenberg™
@@ -14732,15 +14732,15 @@ provided that:
copies of the works possessed in a physical medium and discontinue
all use of and all access to other copies of Project Gutenberg™
works.
• You provide, in accordance with paragraph 1.F.3, a full refund of
any money paid for a work or a replacement copy, if a defect in the
electronic work is discovered and reported to you within 90 days of
receipt of the work.
• You comply with all other terms of this agreement for free
distribution of Project Gutenberg™ works.
1.E.9. If you wish to charge a fee or distribute a Project
Gutenberg™ electronic work or group of works on different terms than
@@ -14903,5 +14903,3 @@ This website includes information about Project Gutenberg™,
including how to make donations to the Project Gutenberg Literary
Archive Foundation, how to help produce our new eBooks, and how to
subscribe to our email newsletter to hear about new eBooks.


@@ -27,7 +27,10 @@ def load_sample_documents():
"title": "Intro to Python", "title": "Intro to Python",
"content": "Python is a high-level, interpreted language known for simplicity.", "content": "Python is a high-level, interpreted language known for simplicity.",
}, },
{"title": "ML Basics", "content": "Machine learning builds systems that learn from data."}, {
"title": "ML Basics",
"content": "Machine learning builds systems that learn from data.",
},
{ {
"title": "Data Structures", "title": "Data Structures",
"content": "Data structures like arrays, lists, and graphs organize data.", "content": "Data structures like arrays, lists, and graphs organize data.",


@@ -21,7 +21,9 @@ DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Googl
def create_leann_index_from_multiple_chrome_profiles(
-   profile_dirs: list[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1
+   profile_dirs: list[Path],
+   index_path: str = "chrome_history_index.leann",
+   max_count: int = -1,
):
    """
    Create LEANN index from multiple Chrome profile data sources.


@@ -474,7 +474,8 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
            message_group, contact_name
        )
        doc = Document(
-           text=doc_content, metadata={"contact_name": contact_name}
+           text=doc_content,
+           metadata={"contact_name": contact_name},
        )
        docs.append(doc)
        count += 1


@@ -315,7 +315,11 @@ async def main():
    # Create or load the LEANN index from all sources
    index_path = create_leann_index_from_multiple_sources(
-       messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model
+       messages_dirs,
+       INDEX_PATH,
+       args.max_emails,
+       args.include_html,
+       args.embedding_model,
    )
    if index_path:


@@ -92,7 +92,10 @@ def main():
help="Directory to store the index (default: mail_index_embedded)", help="Directory to store the index (default: mail_index_embedded)",
) )
parser.add_argument( parser.add_argument(
"--max-emails", type=int, default=10000, help="Maximum number of emails to process" "--max-emails",
type=int,
default=10000,
help="Maximum number of emails to process",
) )
parser.add_argument( parser.add_argument(
"--include-html", "--include-html",
@@ -112,7 +115,10 @@ def main():
    else:
        print("Creating new index...")
        index = create_and_save_index(
-           mail_path, save_dir, max_count=args.max_emails, include_html=args.include_html
+           mail_path,
+           save_dir,
+           max_count=args.max_emails,
+           include_html=args.include_html,
        )
    if index:
        queries = [


@@ -347,7 +347,9 @@ def demo_aggregation():
print(f"\n{'=' * 20} {method.upper()} AGGREGATION {'=' * 20}") print(f"\n{'=' * 20} {method.upper()} AGGREGATION {'=' * 20}")
aggregator = MultiVectorAggregator( aggregator = MultiVectorAggregator(
aggregation_method=method, spatial_clustering=True, cluster_distance_threshold=100.0 aggregation_method=method,
spatial_clustering=True,
cluster_distance_threshold=100.0,
) )
aggregated = aggregator.aggregate_results(mock_results, top_k=5) aggregated = aggregator.aggregate_results(mock_results, top_k=5)


@@ -1 +0,0 @@


@@ -16,4 +16,4 @@ wheel.packages = ["leann_backend_diskann"]
editable.mode = "redirect" editable.mode = "redirect"
cmake.build-type = "Release" cmake.build-type = "Release"
build.verbose = true build.verbose = true
build.tool-args = ["-j8"] build.tool-args = ["-j8"]


@@ -2,12 +2,12 @@ syntax = "proto3";
package protoembedding;
message NodeEmbeddingRequest {
  repeated uint32 node_ids = 1;
}
message NodeEmbeddingResponse {
  bytes embeddings_data = 1;        // All embedded binary data
  repeated int32 dimensions = 2;    // Shape [batch_size, embedding_dim]
  repeated uint32 missing_ids = 3;  // Missing node ids
}
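For reference, the response above packs every embedding into one `embeddings_data` blob, with `dimensions` giving the `[batch_size, embedding_dim]` shape. A minimal client-side sketch of unpacking such a payload, assuming float32 row-major packing (the dtype is not specified in the .proto):

```python
import numpy as np

def decode_embeddings(embeddings_data: bytes, dimensions: list[int]) -> np.ndarray:
    """Unpack a NodeEmbeddingResponse-style payload into a (batch_size, embedding_dim) array.
    float32 row-major layout is an assumption, not something the .proto guarantees."""
    batch_size, embedding_dim = dimensions
    return np.frombuffer(embeddings_data, dtype=np.float32).reshape(batch_size, embedding_dim)

# Toy payload: 2 embeddings of dimension 4
payload = np.arange(8, dtype=np.float32).tobytes()
print(decode_embeddings(payload, [2, 4]).shape)  # (2, 4)
```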


@@ -52,4 +52,4 @@ set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
# IMPORTANT: Disable building AVX versions to speed up compilation
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
add_subdirectory(third_party/faiss)


@@ -72,7 +72,11 @@ def read_vector_raw(f, element_fmt_char):
def read_numpy_vector(f, np_dtype, struct_fmt_char):
    """Reads a vector into a NumPy array."""
    count = -1  # Initialize count for robust error handling
-   print(f"  Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end="", flush=True)
+   print(
+       f"  Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
+       end="",
+       flush=True,
+   )
    try:
        count, data_bytes = read_vector_raw(f, struct_fmt_char)
        print(f"Count={count}, Bytes={len(data_bytes)}")
@@ -647,7 +651,10 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
print(f"Error: Input file not found: {input_filename}", file=sys.stderr) print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
return False return False
except MemoryError as e: except MemoryError as e:
print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr) print(
f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
file=sys.stderr,
)
# Clean up potentially partially written output file? # Clean up potentially partially written output file?
try: try:
os.remove(output_filename) os.remove(output_filename)
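For context on what `convert_hnsw_graph_to_csr` targets: a CSR (compressed sparse row) graph stores all neighbor lists in one flat `indices` array plus an `indptr` offset array. A small illustrative sketch of that layout follows; the converter's actual on-disk format is not shown in this hunk.

```python
import numpy as np

# Adjacency lists for a 4-node toy graph
adjacency = {0: [1, 2], 1: [2], 2: [0, 1, 3], 3: []}

indptr = np.zeros(len(adjacency) + 1, dtype=np.int64)  # offsets into `indices`
indices = []
for node in range(len(adjacency)):
    indices.extend(adjacency[node])
    indptr[node + 1] = indptr[node] + len(adjacency[node])
indices = np.asarray(indices, dtype=np.int64)

# Neighbors of node 2 live in indices[indptr[2]:indptr[3]]
print(indices[indptr[2]:indptr[3]])  # [0 1 3]
```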


@@ -9,7 +9,7 @@ name = "leann-backend-hnsw"
version = "0.1.14" version = "0.1.14"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit." description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [ dependencies = [
"leann-core==0.1.14", "leann-core==0.1.14",
"numpy", "numpy",
"pyzmq>=23.0.0", "pyzmq>=23.0.0",
"msgpack>=1.0.0", "msgpack>=1.0.0",
@@ -24,4 +24,4 @@ build.tool-args = ["-j8"]
# CMake definitions to optimize compilation
[tool.scikit-build.cmake.define]
CMAKE_BUILD_PARALLEL_LEVEL = "8"


@@ -46,4 +46,4 @@ colab = [
leann = "leann.cli:main" leann = "leann.cli:main"
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
where = ["src"] where = ["src"]


@@ -245,7 +245,11 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
        # HF Hub's search is already fuzzy! It handles typos and partial matches
        models = list_models(
-           search=query, filter="text-generation", sort="downloads", direction=-1, limit=limit
+           search=query,
+           filter="text-generation",
+           sort="downloads",
+           direction=-1,
+           limit=limit,
        )
        model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
@@ -582,7 +586,11 @@ class HFChat(LLMInterface):
        # Tokenize input
        inputs = self.tokenizer(
-           formatted_prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048
+           formatted_prompt,
+           return_tensors="pt",
+           padding=True,
+           truncation=True,
+           max_length=2048,
        )
        # Move inputs to device
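The hunk above only covers tokenization; the trailing comment hints at the next step of moving the encoded tensors to the model's device and generating. A self-contained sketch of that flow with plain `transformers` (the model id and generation settings are illustrative assumptions, not HFChat's actual defaults):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # any causal LM works; chosen only for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer(
    "Hello, LEANN!",
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=2048,
)
inputs = {k: v.to(model.device) for k, v in inputs.items()}  # move inputs to device

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```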


@@ -37,4 +37,4 @@ For full documentation, visit [https://leann.readthedocs.io](https://leann.readt
## License
MIT License


@@ -39,4 +39,4 @@ diskann = [
Homepage = "https://github.com/yourusername/leann" Homepage = "https://github.com/yourusername/leann"
Documentation = "https://leann.readthedocs.io" Documentation = "https://leann.readthedocs.io"
Repository = "https://github.com/yourusername/leann" Repository = "https://github.com/yourusername/leann"
Issues = "https://github.com/yourusername/leann/issues" Issues = "https://github.com/yourusername/leann/issues"


@@ -1,6 +1,6 @@
import json
import sqlite3
-import xml.etree.ElementTree as ET
+import xml.etree.ElementTree as ElementTree
from pathlib import Path
from typing import Annotated
@@ -26,7 +26,7 @@ def get_safe_path(s: str) -> str:
def process_history(history: str):
    if history.startswith("<?xml") or history.startswith("<msg>"):
        try:
-           root = ET.fromstring(history)
+           root = ElementTree.fromstring(history)
            title = root.find(".//title").text if root.find(".//title") is not None else None
            quoted = (
                root.find(".//refermsg/content").text
@@ -52,7 +52,8 @@ def get_message(history: dict | str):
def export_chathistory(user_id: str):
    res = requests.get(
-       "http://localhost:48065/wechat/chatlog", params={"userId": user_id, "count": 100000}
+       "http://localhost:48065/wechat/chatlog",
+       params={"userId": user_id, "count": 100000},
    ).json()
    for i in range(len(res["chatLogs"])):
        res["chatLogs"][i]["content"] = process_history(res["chatLogs"][i]["content"])
@@ -116,7 +117,8 @@ def export_sqlite(
    all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
    for user in tqdm(all_users):
        cursor.execute(
-           "INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)", (user["arg"], user["title"])
+           "INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)",
+           (user["arg"], user["title"]),
        )
        usr_chatlog = export_chathistory(user["arg"])
        for msg in usr_chatlog:


@@ -19,16 +19,16 @@ uv pip install build twine delocate auditwheel scikit-build-core cmake pybind11
build_package() {
    local package_dir=$1
    local package_name=$(basename $package_dir)
    echo "Building $package_name..."
    cd $package_dir
    # Clean previous builds
    rm -rf dist/ build/ _skbuild/
    # Build directly with pip wheel (avoids sdist issues)
    pip wheel . --no-deps -w dist
    # Repair wheel for binary packages
    if [[ "$package_name" != "leann-core" ]] && [[ "$package_name" != "leann" ]]; then
        if [[ "$OSTYPE" == "darwin"* ]]; then
@@ -57,7 +57,7 @@ build_package() {
            fi
        fi
    fi
    echo "Built wheels in $package_dir/dist/"
    ls -la dist/
    cd - > /dev/null
@@ -84,4 +84,4 @@ else
fi
echo -e "\nBuild complete! Test with:"
echo "uv pip install packages/*/dist/*.whl"


@@ -28,4 +28,4 @@ else
fi
echo "✅ Version updated to $NEW_VERSION"
echo "✅ Dependencies updated to use leann-core==$NEW_VERSION"


@@ -15,4 +15,4 @@ VERSION=$1
git add . && git commit -m "chore: bump version to $VERSION" && git push git add . && git commit -m "chore: bump version to $VERSION" && git push
# Create release (triggers CI) # Create release (triggers CI)
gh release create v$VERSION --generate-notes gh release create v$VERSION --generate-notes


@@ -27,4 +27,4 @@ else
    else
        echo "Cancelled"
    fi
fi


@@ -58,7 +58,8 @@ class GraphWrapper:
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_output = self.model(
-                   input_ids=self.static_input, attention_mask=self.static_attention_mask
+                   input_ids=self.static_input,
+                   attention_mask=self.static_attention_mask,
                )
            self.use_cuda_graph = True
        else:
@@ -82,7 +83,10 @@ class GraphWrapper:
    def _warmup(self, num_warmup: int = 3):
        with torch.no_grad():
            for _ in range(num_warmup):
-               self.model(input_ids=self.static_input, attention_mask=self.static_attention_mask)
+               self.model(
+                   input_ids=self.static_input,
+                   attention_mask=self.static_attention_mask,
+               )
    def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        if self.use_cuda_graph:
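The two hunks above touch the warmup and capture halves of GraphWrapper's CUDA Graph path; the replay half is not shown in this diff. A minimal sketch of the full capture/replay pattern with a toy module (the replay logic here is an assumption about what `__call__` does, not a copy of it):

```python
import torch

class ToyModel(torch.nn.Module):
    """Stand-in for the benchmarked transformer; shapes only."""
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(1000, 64)
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, input_ids, attention_mask):
        return self.proj(self.embed(input_ids)) * attention_mask.unsqueeze(-1)

model = ToyModel().cuda().eval()
static_input = torch.randint(0, 1000, (8, 128), device="cuda", dtype=torch.long)
static_attention_mask = torch.ones_like(static_input)

# 1) Warm up before capture so allocator/kernel setup is out of the way.
with torch.no_grad():
    for _ in range(3):
        model(input_ids=static_input, attention_mask=static_attention_mask)

# 2) Capture one forward pass; inputs and outputs live in fixed (static) tensors.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(input_ids=static_input, attention_mask=static_attention_mask)

# 3) Replay: copy fresh data into the captured buffers, then replay the graph.
def run(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    static_input.copy_(input_ids)
    static_attention_mask.copy_(attention_mask)
    graph.replay()
    return static_output
```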
@@ -261,7 +265,10 @@ class Benchmark:
                # print size
                print(f"in_features: {in_features}, out_features: {out_features}")
                new_module = bnb.nn.Linear8bitLt(
-                   in_features, out_features, bias=bias, has_fp16_weights=False
+                   in_features,
+                   out_features,
+                   bias=bias,
+                   has_fp16_weights=False,
                )
                # Copy weights and bias
@@ -350,8 +357,6 @@ class Benchmark:
        # Try xformers if available (only on CUDA)
        if torch.cuda.is_available():
            try:
-               from xformers.ops import memory_efficient_attention  # noqa: F401
                if hasattr(model, "enable_xformers_memory_efficient_attention"):
                    model.enable_xformers_memory_efficient_attention()
                    print("- Enabled xformers memory efficient attention")
@@ -427,7 +432,11 @@ class Benchmark:
else "cpu" else "cpu"
) )
return torch.randint( return torch.randint(
0, 1000, (batch_size, self.config.seq_length), device=device, dtype=torch.long 0,
1000,
(batch_size, self.config.seq_length),
device=device,
dtype=torch.long,
) )
def _run_inference( def _run_inference(


@@ -7,7 +7,7 @@ This directory contains comprehensive sanity checks for the Leann system, ensuri
### `test_distance_functions.py`
Tests all supported distance functions across the DiskANN backend:
- **MIPS** (Maximum Inner Product Search)
- **L2** (Euclidean Distance)
- **Cosine** (Cosine Similarity)
```bash
@@ -27,7 +27,7 @@ uv run python tests/sanity_checks/test_l2_verification.py
### `test_sanity_check.py`
Comprehensive end-to-end verification including:
- Distance function testing
- Embedding model compatibility
- Search result correctness validation
- Backend integration testing
@@ -64,7 +64,7 @@ When all tests pass, you should see:
```
📊 测试结果总结:
  mips   : ✅ 通过
  l2     : ✅ 通过
  cosine : ✅ 通过
🎉 测试完成!
@@ -98,7 +98,7 @@ pkill -f "embedding_server"
### Typical Timing (3 documents, consumer hardware):
- **Index Building**: 2-5 seconds per distance function
- **Search Query**: 50-200ms
- **Recompute Mode**: 5-15 seconds (higher accuracy)
### Memory Usage:
@@ -117,4 +117,4 @@ These tests are designed to be run in automated environments:
uv run python tests/sanity_checks/test_l2_verification.py
```
The tests are deterministic and should produce consistent results across different platforms.


@@ -115,7 +115,13 @@ def main():
    # --- Plotting ---
    print("\n--- Generating Plot ---")
    plt.figure(figsize=(10, 6))
-   plt.plot(BATCH_SIZES, results_torch, marker="o", linestyle="-", label=f"PyTorch ({device})")
+   plt.plot(
+       BATCH_SIZES,
+       results_torch,
+       marker="o",
+       linestyle="-",
+       label=f"PyTorch ({device})",
+   )
    plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
    plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")


@@ -170,7 +170,11 @@ class Benchmark:
    def _create_random_batch(self, batch_size: int) -> torch.Tensor:
        return torch.randint(
-           0, 1000, (batch_size, self.config.seq_length), device=self.device, dtype=torch.long
+           0,
+           1000,
+           (batch_size, self.config.seq_length),
+           device=self.device,
+           dtype=torch.long,
        )
    def _run_inference(self, input_ids: torch.Tensor) -> float:
@@ -256,7 +260,11 @@ def run_mlx_benchmark():
"""Run MLX-specific benchmark""" """Run MLX-specific benchmark"""
if not MLX_AVAILABLE: if not MLX_AVAILABLE:
print("MLX not available, skipping MLX benchmark") print("MLX not available, skipping MLX benchmark")
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": "MLX not available"} return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "MLX not available",
}
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True) config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
@@ -265,7 +273,11 @@ def run_mlx_benchmark():
    results = benchmark.run()
    if not results:
-       return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": "No valid results"}
+       return {
+           "max_throughput": 0.0,
+           "avg_throughput": 0.0,
+           "error": "No valid results",
+       }
    max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
    avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])