feat: mlx

Andy Lee
2025-07-13 02:13:04 -07:00
parent 71ef4b7d4c
commit 48dda1cb5b
4 changed files with 278 additions and 60 deletions

build_mlx_index.py Normal file (34 additions)
View File

@@ -0,0 +1,34 @@
from leann.api import LeannBuilder
import os

# Define the path for our new MLX-based index
INDEX_PATH = "./mlx_diskann_index/leann"

if os.path.exists(INDEX_PATH + ".meta.json"):
    print(f"Index already exists at {INDEX_PATH}. Skipping build.")
else:
    print("Initializing LeannBuilder with MLX support...")

    # 1. Configure LeannBuilder to use MLX
    builder = LeannBuilder(
        backend_name="diskann",
        embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
        use_mlx=True
    )

    # 2. Add documents
    print("Adding documents...")
    docs = [
        "MLX is an array framework for machine learning on Apple silicon.",
        "It was designed by Apple's machine learning research team.",
        "The mlx-community organization provides pre-trained models in MLX format.",
        "It supports operations on multi-dimensional arrays.",
        "Leann can now use MLX for its embedding models."
    ]
    for doc in docs:
        builder.add_text(doc)

    # 3. Build the index
    print(f"Building the MLX-based index at: {INDEX_PATH}")
    builder.build_index(INDEX_PATH)

    print("\nSuccessfully built the index with MLX embeddings!")
    print(f"Check the metadata file: {INDEX_PATH}.meta.json")

View File
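As a quick sanity check of the script above (not part of the commit), the index it produces can then be queried. This sketch assumes the LeannSearcher constructor and search() call suggested by the LEANN API diff further down, so exact argument names may differ.

from leann.api import LeannSearcher

# Load the MLX-built index by the same path prefix used above
searcher = LeannSearcher("./mlx_diskann_index/leann")
# search() embeds the query with the same MLX model recorded in the .meta.json
results = searcher.search("What is MLX?", top_k=3)
print(results)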

@@ -5,7 +5,6 @@ Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
 import pickle
 import argparse
-import threading
 import time
 import json
 from typing import Dict, Any, Optional, Union
@@ -16,7 +15,6 @@ from contextlib import contextmanager
 import zmq
 import numpy as np
 from pathlib import Path
-import pickle

 RED = "\033[91m"
 RESET = "\033[0m"
@@ -154,6 +152,7 @@ def create_embedding_server_thread(
     model_name="sentence-transformers/all-mpnet-base-v2",
     max_batch_size=128,
     passages_file: Optional[str] = None,
+    use_mlx: bool = False,
 ):
     """
     Create and run the embedding server in the current thread
@@ -172,36 +171,40 @@ def create_embedding_server_thread(
         print(f"{RED}Port {zmq_port} is already in use{RESET}")
         return

-    # Initialize the model
-    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-    import torch
-    # Select the device
-    mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
-    cuda_available = torch.cuda.is_available()
-    if cuda_available:
-        device = torch.device("cuda")
-        print("INFO: Using CUDA device")
-    elif mps_available:
-        device = torch.device("mps")
-        print("INFO: Using MPS device (Apple Silicon)")
-    else:
-        device = torch.device("cpu")
-        print("INFO: Using CPU device")
-    # Load the model
-    print(f"INFO: Loading model {model_name}")
-    model = AutoModel.from_pretrained(model_name).to(device).eval()
-    # Optimize the model
-    if cuda_available or mps_available:
-        try:
-            model = model.half()
-            model = torch.compile(model)
-            print(f"INFO: Using FP16 precision with model: {model_name}")
-        except Exception as e:
-            print(f"WARNING: Model optimization failed: {e}")
+    if use_mlx:
+        from leann.api import compute_embeddings_mlx
+        print("INFO: Using MLX for embeddings")
+    else:
+        # Initialize the model
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+        import torch
+        # Select the device
+        mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
+        cuda_available = torch.cuda.is_available()
+        if cuda_available:
+            device = torch.device("cuda")
+            print("INFO: Using CUDA device")
+        elif mps_available:
+            device = torch.device("mps")
+            print("INFO: Using MPS device (Apple Silicon)")
+        else:
+            device = torch.device("cpu")
+            print("INFO: Using CPU device")
+        # Load the model
+        print(f"INFO: Loading model {model_name}")
+        model = AutoModel.from_pretrained(model_name).to(device).eval()
+        # Optimize the model
+        if cuda_available or mps_available:
+            try:
+                model = model.half()
+                model = torch.compile(model)
+                print(f"INFO: Using FP16 precision with model: {model_name}")
+            except Exception as e:
+                print(f"WARNING: Model optimization failed: {e}")

     # Load passages from file if provided
     if passages_file and os.path.exists(passages_file):
@@ -233,7 +236,7 @@ def create_embedding_server_thread(
             self.start_time = 0
             self.end_time = 0
-            if cuda_available:
+            if not use_mlx and torch.cuda.is_available():
                 self.start_event = torch.cuda.Event(enable_timing=True)
                 self.end_event = torch.cuda.Event(enable_timing=True)
             else:
@@ -247,25 +250,25 @@ def create_embedding_server_thread(
             self.end()

         def start(self):
-            if cuda_available:
+            if not use_mlx and torch.cuda.is_available():
                 torch.cuda.synchronize()
                 self.start_event.record()
             else:
-                if self.device.type == "mps":
+                if not use_mlx and self.device.type == "mps":
                     torch.mps.synchronize()
                 self.start_time = time.time()

         def end(self):
-            if cuda_available:
+            if not use_mlx and torch.cuda.is_available():
                 self.end_event.record()
                 torch.cuda.synchronize()
             else:
-                if self.device.type == "mps":
+                if not use_mlx and self.device.type == "mps":
                     torch.mps.synchronize()
                 self.end_time = time.time()

         def elapsed_time(self):
-            if cuda_available:
+            if not use_mlx and torch.cuda.is_available():
                 return self.start_event.elapsed_time(self.end_event) / 1000.0
             else:
                 return self.end_time - self.start_time
@@ -273,7 +276,7 @@ def create_embedding_server_thread(
         def print_elapsed(self):
             print(f"Time taken for {self.name}: {self.elapsed_time():.6f} seconds")

-    def process_batch(texts_batch, ids_batch, missing_ids):
+    def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
         """Process a batch of texts."""
         batch_size = len(texts_batch)
         print(f"INFO: Processing batch of size {batch_size}")
@@ -351,7 +354,7 @@ def create_embedding_server_thread(
             print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
             e2e_start = time.time()
-            lookup_timer = DeviceTimer("text lookup", device)
+            lookup_timer = DeviceTimer("text lookup")

             # Parse the request
             req_proto = embedding_pb2.NodeEmbeddingRequest()
@@ -397,18 +400,25 @@ def create_embedding_server_thread(
                     chunk_texts = texts[i:end_idx]
                     chunk_ids = node_ids[i:end_idx]

-                    embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
+                    if use_mlx:
+                        embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name)
+                    else:
+                        embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
                     all_embeddings.append(embeddings_chunk)

-                    if cuda_available:
-                        torch.cuda.empty_cache()
-                    elif device.type == "mps":
-                        torch.mps.empty_cache()
+                    if not use_mlx:
+                        if cuda_available:
+                            torch.cuda.empty_cache()
+                        elif device.type == "mps":
+                            torch.mps.empty_cache()

                 hidden = np.vstack(all_embeddings)
                 print(f"INFO: Combined embeddings shape: {hidden.shape}")
             else:
-                hidden = process_batch(texts, node_ids, missing_ids)
+                if use_mlx:
+                    hidden = compute_embeddings_mlx(texts, model_name)
+                else:
+                    hidden = process_batch_pytorch(texts, node_ids, missing_ids)

             # Serialize the response
             ser_start = time.time()
@@ -429,16 +439,16 @@ def create_embedding_server_thread(
             print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")

-            if device.type == "cuda":
-                torch.cuda.synchronize()
-            elif device.type == "mps":
-                torch.mps.synchronize()
+            if not use_mlx:
+                if device.type == "cuda":
+                    torch.cuda.synchronize()
+                elif device.type == "mps":
+                    torch.mps.synchronize()
             e2e_end = time.time()
             print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")

         except zmq.Again:
             print("INFO: ZMQ socket timeout, continuing to listen")
-            # The REP socket does not need to be recreated; just keep listening
             continue
         except Exception as e:
             print(f"ERROR: Error in ZMQ server: {e}")
@@ -460,7 +470,6 @@ def create_embedding_server_thread(
             raise

-# Keep the original create_embedding_server function unchanged; only add a threaded version
 def create_embedding_server(
     domain="demo",
     load_passages=True,
@@ -473,12 +482,13 @@ def create_embedding_server(
     lazy_load_passages=False,
     model_name="sentence-transformers/all-mpnet-base-v2",
     passages_file: Optional[str] = None,
+    use_mlx: bool = False,
 ):
     """
     The original create_embedding_server function stays unchanged.
     This is the blocking version, meant to be run directly.
     """
-    create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file)
+    create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx)


 if __name__ == "__main__":
@@ -495,6 +505,7 @@ if __name__ == "__main__":
     parser.add_argument("--lazy-load-passages", action="store_true", default=True)
     parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
                         help="Embedding model name")
+    parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings")

     args = parser.parse_args()
     create_embedding_server(
@@ -509,4 +520,5 @@ if __name__ == "__main__":
         lazy_load_passages=args.lazy_load_passages,
         model_name=args.model_name,
         passages_file=args.passages_file,
+        use_mlx=args.use_mlx,
     )

View File
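Not part of the commit, but for orientation: a minimal sketch of starting the server above with the new flag from Python, using only parameters visible in its signature. The import path is hypothetical, since the module's actual location is not shown in this diff.

# Hypothetical import path -- the server module's real location is not shown here.
from embedding_server import create_embedding_server

create_embedding_server(
    model_name="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",  # an MLX-format model
    use_mlx=True,  # routes embedding requests through compute_embeddings_mlx
)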

@@ -1,3 +1,4 @@
 """
 This file contains the core API for the LEANN project, now definitively updated
 with the correct, original embedding logic from the user's reference code.
@@ -17,8 +18,10 @@ from .interface import LeannBackendFactoryInterface
 # --- The Correct, Verified Embedding Logic from old_code.py ---
-def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
-    """Computes embeddings using sentence-transformers for consistent results."""
+def compute_embeddings(chunks: List[str], model_name: str, use_mlx: bool = False) -> np.ndarray:
+    """Computes embeddings using sentence-transformers or MLX for consistent results."""
+    if use_mlx:
+        return compute_embeddings_mlx(chunks, model_name)
     try:
         from sentence_transformers import SentenceTransformer
     except ImportError as e:
@@ -44,6 +47,45 @@ def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
     return embeddings

+
+def compute_embeddings_mlx(chunks: List[str], model_name: str) -> np.ndarray:
+    """Computes embeddings using an MLX model."""
+    try:
+        import mlx.core as mx
+        from mlx_lm.utils import load
+    except ImportError as e:
+        raise RuntimeError(
+            "MLX or related libraries not available. Install with: pip install mlx mlx-lm"
+        ) from e
+
+    print(f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}'...")
+
+    # Load model and tokenizer
+    model, tokenizer = load(model_name)
+
+    # Process each chunk
+    all_embeddings = []
+    for chunk in chunks:
+        # Tokenize
+        token_ids = tokenizer.encode(chunk)
+        # Convert to MLX array and add batch dimension
+        input_ids = mx.array([token_ids])
+        # Get embeddings
+        embeddings = model(input_ids)
+        # Mean pooling (since we only have one sequence, just take the mean)
+        pooled = embeddings.mean(axis=1)  # Shape: (1, hidden_size)
+        # Convert individual embedding to numpy via list (to handle bfloat16)
+        pooled_list = pooled[0].tolist()  # Remove batch dimension and convert to list
+        pooled_numpy = np.array(pooled_list, dtype=np.float32)
+        all_embeddings.append(pooled_numpy)
+
+    # Stack numpy arrays
+    return np.stack(all_embeddings)
+
 # --- Core API Classes (Restored and Unchanged) ---
 @dataclass
@@ -83,7 +125,7 @@ class PassageManager:
             raise KeyError(f"Passage ID not found: {passage_id}")

 class LeannBuilder:
-    def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, **backend_kwargs):
+    def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, use_mlx: bool = False, **backend_kwargs):
         self.backend_name = backend_name
         backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
@@ -91,6 +133,7 @@ class LeannBuilder:
         self.backend_factory = backend_factory
         self.embedding_model = embedding_model
         self.dimensions = dimensions
+        self.use_mlx = use_mlx
         self.backend_kwargs = backend_kwargs
         self.chunks: List[Dict[str, Any]] = []
@@ -102,7 +145,7 @@ class LeannBuilder:
     def build_index(self, index_path: str):
         if not self.chunks: raise ValueError("No chunks added.")
-        if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model)[0])
+        if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model, self.use_mlx)[0])
         path = Path(index_path)
         index_dir = path.parent
         index_name = path.name
@@ -118,7 +161,7 @@ class LeannBuilder:
             offset_map[chunk["id"]] = offset
         with open(offset_file, 'wb') as f: pickle.dump(offset_map, f)
         texts_to_embed = [c["text"] for c in self.chunks]
-        embeddings = compute_embeddings(texts_to_embed, self.embedding_model)
+        embeddings = compute_embeddings(texts_to_embed, self.embedding_model, self.use_mlx)
         string_ids = [chunk["id"] for chunk in self.chunks]
         current_backend_kwargs = {**self.backend_kwargs, 'dimensions': self.dimensions}
         builder_instance = self.backend_factory.builder(**current_backend_kwargs)
@@ -126,7 +169,7 @@ class LeannBuilder:
         leann_meta_path = index_dir / f"{index_name}.meta.json"
         meta_data = {
             "version": "1.0", "backend_name": self.backend_name, "embedding_model": self.embedding_model,
-            "dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs,
+            "dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs, "use_mlx": self.use_mlx,
             "passage_sources": [{"type": "jsonl", "path": str(passages_file), "index_path": str(offset_file)}]
         }
@@ -145,6 +188,7 @@ class LeannSearcher:
         with open(meta_path_str, 'r', encoding='utf-8') as f: self.meta_data = json.load(f)
         backend_name = self.meta_data['backend_name']
         self.embedding_model = self.meta_data['embedding_model']
+        self.use_mlx = self.meta_data.get('use_mlx', False)
         self.passage_manager = PassageManager(self.meta_data.get('passage_sources', []))
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None: raise ValueError(f"Backend '{backend_name}' not found.")
@@ -157,7 +201,7 @@ class LeannSearcher:
         print(f" Top_k: {top_k}")
         print(f" Search kwargs: {search_kwargs}")
-        query_embedding = compute_embeddings([query], self.embedding_model)
+        query_embedding = compute_embeddings([query], self.embedding_model, self.use_mlx)
         print(f" Generated embedding shape: {query_embedding.shape}")
         print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
         print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")

View File
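One observation on compute_embeddings_mlx above: it runs one forward pass per chunk. A batched variant is sketched below (illustrative only, not part of the commit), borrowing the padding and masked mean-pooling approach from the benchmark script that follows, and assuming the mlx_lm tokenizer exposes encode and eos_token_id exactly as that script uses them.

import mlx.core as mx
import numpy as np
from mlx_lm import load

def compute_embeddings_mlx_batched(chunks, model_name):
    """Sketch: embed all chunks in one forward pass with padding and masked mean pooling."""
    model, tokenizer = load(model_name)
    token_lists = [tokenizer.encode(c) for c in chunks]
    max_len = max(len(t) for t in token_lists)
    # Pad with eos and build an attention mask (1 = real token, 0 = padding)
    input_ids = [t + [tokenizer.eos_token_id] * (max_len - len(t)) for t in token_lists]
    mask = [[1] * len(t) + [0] * (max_len - len(t)) for t in token_lists]
    input_ids = mx.array(input_ids)
    mask = mx.expand_dims(mx.array(mask), -1)
    hidden = model(input_ids)  # shape: (batch, seq, hidden)
    pooled = (hidden * mask).sum(axis=1) / mask.sum(axis=1)
    mx.eval(pooled)
    # Round-trip through a Python list to cope with bfloat16, as the per-chunk version does
    return np.array(pooled.tolist(), dtype=np.float32)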

@@ -0,0 +1,128 @@
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
from sentence_transformers import SentenceTransformer
import mlx.core as mx
from mlx_lm import load
# --- Configuration ---
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
NUM_RUNS = 10 # Number of runs to average for each batch size
WARMUP_RUNS = 2 # Number of warm-up runs
# --- Generate Dummy Data ---
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
# --- Benchmark Functions ---
def benchmark_torch(model, sentences):
    start_time = time.time()
    model.encode(sentences, convert_to_numpy=True)
    end_time = time.time()
    return (end_time - start_time) * 1000  # Return time in ms
def benchmark_mlx(model, tokenizer, sentences):
    start_time = time.time()
    # Tokenize sentences using MLX tokenizer
    tokens = []
    for sentence in sentences:
        token_ids = tokenizer.encode(sentence)
        tokens.append(token_ids)
    # Pad sequences to the same length
    max_len = max(len(t) for t in tokens)
    input_ids = []
    attention_mask = []
    for token_seq in tokens:
        # Pad sequence
        padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
        input_ids.append(padded)
        # Create attention mask (1 for real tokens, 0 for padding)
        mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
        attention_mask.append(mask)
    # Convert to MLX arrays
    input_ids = mx.array(input_ids)
    attention_mask = mx.array(attention_mask)
    # Get embeddings
    embeddings = model(input_ids)
    # Mean pooling over the unpadded tokens
    mask = mx.expand_dims(attention_mask, -1)
    sum_embeddings = (embeddings * mask).sum(axis=1)
    sum_mask = mask.sum(axis=1)
    pooled = sum_embeddings / sum_mask
    mx.eval(pooled)  # Force evaluation so the timing includes the (lazy) computation
    end_time = time.time()
    return (end_time - start_time) * 1000  # Return time in ms
# --- Main Execution ---
def main():
    print("--- Initializing Models ---")

    # Load PyTorch model
    print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
    print(f"PyTorch model loaded on: {device}")

    # Load MLX model
    print(f"Loading MLX model: {MODEL_NAME_MLX}")
    model_mlx, tokenizer_mlx = load(MODEL_NAME_MLX)
    print("MLX model loaded.")

    # --- Warm-up ---
    print("\n--- Performing Warm-up Runs ---")
    for _ in range(WARMUP_RUNS):
        benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
        benchmark_mlx(model_mlx, tokenizer_mlx, DUMMY_SENTENCES[:1])
    print("Warm-up complete.")

    # --- Benchmarking ---
    print("\n--- Starting Benchmark ---")
    results_torch = []
    results_mlx = []
    for batch_size in BATCH_SIZES:
        print(f"Benchmarking batch size: {batch_size}")
        sentences_batch = DUMMY_SENTENCES[:batch_size]

        # Benchmark PyTorch
        torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
        results_torch.append(np.mean(torch_times))

        # Benchmark MLX
        mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)]
        results_mlx.append(np.mean(mlx_times))

    print("\n--- Benchmark Results (Average time per batch in ms) ---")
    print(f"Batch Sizes: {BATCH_SIZES}")
    print(f"PyTorch (mps): {[f'{t:.2f}' for t in results_torch]}")
    print(f"MLX: {[f'{t:.2f}' for t in results_mlx]}")

    # --- Plotting ---
    print("\n--- Generating Plot ---")
    plt.figure(figsize=(10, 6))
    plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})')
    plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX')
    plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}')
    plt.xlabel("Batch Size")
    plt.ylabel("Average Time per Batch (ms)")
    plt.xticks(BATCH_SIZES)
    plt.grid(True)
    plt.legend()

    # Save the plot
    output_filename = "embedding_benchmark.png"
    plt.savefig(output_filename)
    print(f"Plot saved to {output_filename}")


if __name__ == "__main__":
    main()
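A small optional addition (not in the original script): printing a per-batch-size ratio at the end of main() gives a quick textual summary alongside the plot.

    # Optional summary, appended at the end of main(): PyTorch time / MLX time per batch size
    for bs, t_torch, t_mlx in zip(BATCH_SIZES, results_torch, results_mlx):
        print(f"Batch {bs}: PyTorch/MLX time ratio = {t_torch / t_mlx:.2f}x")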