fix ruff errors and formatting

This commit is contained in:
yichuan520030910320
2025-07-27 02:22:54 -07:00
parent 383c6d8d7e
commit af1790395a
35 changed files with 166 additions and 107 deletions

View File

@@ -14903,5 +14903,3 @@ This website includes information about Project Gutenberg™,
including how to make donations to the Project Gutenberg Literary
Archive Foundation, how to help produce our new eBooks, and how to
subscribe to our email newsletter to hear about new eBooks.

View File

@@ -27,7 +27,10 @@ def load_sample_documents():
"title": "Intro to Python",
"content": "Python is a high-level, interpreted language known for simplicity.",
},
{"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
{
"title": "ML Basics",
"content": "Machine learning builds systems that learn from data.",
},
{
"title": "Data Structures",
"content": "Data structures like arrays, lists, and graphs organize data.",

View File

@@ -21,7 +21,9 @@ DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Googl
def create_leann_index_from_multiple_chrome_profiles(
profile_dirs: list[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1
profile_dirs: list[Path],
index_path: str = "chrome_history_index.leann",
max_count: int = -1,
):
"""
Create LEANN index from multiple Chrome profile data sources.

View File

@@ -474,7 +474,8 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
message_group, contact_name
)
doc = Document(
text=doc_content, metadata={"contact_name": contact_name}
text=doc_content,
metadata={"contact_name": contact_name},
)
docs.append(doc)
count += 1

View File

@@ -315,7 +315,11 @@ async def main():
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_sources(
messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model
messages_dirs,
INDEX_PATH,
args.max_emails,
args.include_html,
args.embedding_model,
)
if index_path:

View File

@@ -92,7 +92,10 @@ def main():
help="Directory to store the index (default: mail_index_embedded)",
)
parser.add_argument(
"--max-emails", type=int, default=10000, help="Maximum number of emails to process"
"--max-emails",
type=int,
default=10000,
help="Maximum number of emails to process",
)
parser.add_argument(
"--include-html",
@@ -112,7 +115,10 @@ def main():
else:
print("Creating new index...")
index = create_and_save_index(
mail_path, save_dir, max_count=args.max_emails, include_html=args.include_html
mail_path,
save_dir,
max_count=args.max_emails,
include_html=args.include_html,
)
if index:
queries = [

View File

@@ -347,7 +347,9 @@ def demo_aggregation():
print(f"\n{'=' * 20} {method.upper()} AGGREGATION {'=' * 20}")
aggregator = MultiVectorAggregator(
aggregation_method=method, spatial_clustering=True, cluster_distance_threshold=100.0
aggregation_method=method,
spatial_clustering=True,
cluster_distance_threshold=100.0,
)
aggregated = aggregator.aggregate_results(mock_results, top_k=5)

View File

@@ -1 +0,0 @@

View File

@@ -72,7 +72,11 @@ def read_vector_raw(f, element_fmt_char):
def read_numpy_vector(f, np_dtype, struct_fmt_char):
"""Reads a vector into a NumPy array."""
count = -1 # Initialize count for robust error handling
print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end="", flush=True)
print(
f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
end="",
flush=True,
)
try:
count, data_bytes = read_vector_raw(f, struct_fmt_char)
print(f"Count={count}, Bytes={len(data_bytes)}")
@@ -647,7 +651,10 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
return False
except MemoryError as e:
print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr)
print(
f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
file=sys.stderr,
)
# Clean up potentially partially written output file?
try:
os.remove(output_filename)

View File

@@ -245,7 +245,11 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
# HF Hub's search is already fuzzy! It handles typos and partial matches
models = list_models(
search=query, filter="text-generation", sort="downloads", direction=-1, limit=limit
search=query,
filter="text-generation",
sort="downloads",
direction=-1,
limit=limit,
)
model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
@@ -582,7 +586,11 @@ class HFChat(LLMInterface):
# Tokenize input
inputs = self.tokenizer(
formatted_prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048,
)
# Move inputs to device

View File

@@ -1,6 +1,6 @@
import json
import sqlite3
import xml.etree.ElementTree as ET
import xml.etree.ElementTree as ElementTree
from pathlib import Path
from typing import Annotated
@@ -26,7 +26,7 @@ def get_safe_path(s: str) -> str:
def process_history(history: str):
if history.startswith("<?xml") or history.startswith("<msg>"):
try:
root = ET.fromstring(history)
root = ElementTree.fromstring(history)
title = root.find(".//title").text if root.find(".//title") is not None else None
quoted = (
root.find(".//refermsg/content").text
@@ -52,7 +52,8 @@ def get_message(history: dict | str):
def export_chathistory(user_id: str):
res = requests.get(
"http://localhost:48065/wechat/chatlog", params={"userId": user_id, "count": 100000}
"http://localhost:48065/wechat/chatlog",
params={"userId": user_id, "count": 100000},
).json()
for i in range(len(res["chatLogs"])):
res["chatLogs"][i]["content"] = process_history(res["chatLogs"][i]["content"])
@@ -116,7 +117,8 @@ def export_sqlite(
all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
for user in tqdm(all_users):
cursor.execute(
"INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)", (user["arg"], user["title"])
"INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)",
(user["arg"], user["title"]),
)
usr_chatlog = export_chathistory(user["arg"])
for msg in usr_chatlog:

View File

@@ -58,7 +58,8 @@ class GraphWrapper:
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.model(
input_ids=self.static_input, attention_mask=self.static_attention_mask
input_ids=self.static_input,
attention_mask=self.static_attention_mask,
)
self.use_cuda_graph = True
else:
@@ -82,7 +83,10 @@ class GraphWrapper:
def _warmup(self, num_warmup: int = 3):
with torch.no_grad():
for _ in range(num_warmup):
self.model(input_ids=self.static_input, attention_mask=self.static_attention_mask)
self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask,
)
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
if self.use_cuda_graph:
@@ -261,7 +265,10 @@ class Benchmark:
# print size
print(f"in_features: {in_features}, out_features: {out_features}")
new_module = bnb.nn.Linear8bitLt(
in_features, out_features, bias=bias, has_fp16_weights=False
in_features,
out_features,
bias=bias,
has_fp16_weights=False,
)
# Copy weights and bias
@@ -350,8 +357,6 @@ class Benchmark:
# Try xformers if available (only on CUDA)
if torch.cuda.is_available():
try:
from xformers.ops import memory_efficient_attention # noqa: F401
if hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
@@ -427,7 +432,11 @@ class Benchmark:
else "cpu"
)
return torch.randint(
0, 1000, (batch_size, self.config.seq_length), device=device, dtype=torch.long
0,
1000,
(batch_size, self.config.seq_length),
device=device,
dtype=torch.long,
)
def _run_inference(

View File

@@ -115,7 +115,13 @@ def main():
# --- Plotting ---
print("\n--- Generating Plot ---")
plt.figure(figsize=(10, 6))
plt.plot(BATCH_SIZES, results_torch, marker="o", linestyle="-", label=f"PyTorch ({device})")
plt.plot(
BATCH_SIZES,
results_torch,
marker="o",
linestyle="-",
label=f"PyTorch ({device})",
)
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")

View File

@@ -170,7 +170,11 @@ class Benchmark:
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint(
0, 1000, (batch_size, self.config.seq_length), device=self.device, dtype=torch.long
0,
1000,
(batch_size, self.config.seq_length),
device=self.device,
dtype=torch.long,
)
def _run_inference(self, input_ids: torch.Tensor) -> float:
@@ -256,7 +260,11 @@ def run_mlx_benchmark():
"""Run MLX-specific benchmark"""
if not MLX_AVAILABLE:
print("MLX not available, skipping MLX benchmark")
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": "MLX not available"}
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "MLX not available",
}
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
@@ -265,7 +273,11 @@ def run_mlx_benchmark():
results = benchmark.run()
if not results:
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": "No valid results"}
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "No valid results",
}
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])