fix ruff errors and formatting
@@ -14903,5 +14903,3 @@ This website includes information about Project Gutenberg™,
 including how to make donations to the Project Gutenberg Literary
 Archive Foundation, how to help produce our new eBooks, and how to
 subscribe to our email newsletter to hear about new eBooks.
-
-
@@ -27,7 +27,10 @@ def load_sample_documents():
             "title": "Intro to Python",
             "content": "Python is a high-level, interpreted language known for simplicity.",
         },
-        {"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
+        {
+            "title": "ML Basics",
+            "content": "Machine learning builds systems that learn from data.",
+        },
         {
             "title": "Data Structures",
             "content": "Data structures like arrays, lists, and graphs organize data.",
@@ -21,7 +21,9 @@ DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Googl
 
 
 def create_leann_index_from_multiple_chrome_profiles(
-    profile_dirs: list[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1
+    profile_dirs: list[Path],
+    index_path: str = "chrome_history_index.leann",
+    max_count: int = -1,
 ):
     """
     Create LEANN index from multiple Chrome profile data sources.
@@ -474,7 +474,8 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
                     message_group, contact_name
                 )
                 doc = Document(
-                    text=doc_content, metadata={"contact_name": contact_name}
+                    text=doc_content,
+                    metadata={"contact_name": contact_name},
                 )
                 docs.append(doc)
                 count += 1
@@ -315,7 +315,11 @@ async def main():
 
     # Create or load the LEANN index from all sources
     index_path = create_leann_index_from_multiple_sources(
-        messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model
+        messages_dirs,
+        INDEX_PATH,
+        args.max_emails,
+        args.include_html,
+        args.embedding_model,
     )
 
     if index_path:
@@ -92,7 +92,10 @@ def main():
         help="Directory to store the index (default: mail_index_embedded)",
     )
     parser.add_argument(
-        "--max-emails", type=int, default=10000, help="Maximum number of emails to process"
+        "--max-emails",
+        type=int,
+        default=10000,
+        help="Maximum number of emails to process",
     )
     parser.add_argument(
         "--include-html",
@@ -112,7 +115,10 @@ def main():
     else:
         print("Creating new index...")
         index = create_and_save_index(
-            mail_path, save_dir, max_count=args.max_emails, include_html=args.include_html
+            mail_path,
+            save_dir,
+            max_count=args.max_emails,
+            include_html=args.include_html,
         )
     if index:
         queries = [
@@ -347,7 +347,9 @@ def demo_aggregation():
         print(f"\n{'=' * 20} {method.upper()} AGGREGATION {'=' * 20}")
 
         aggregator = MultiVectorAggregator(
-            aggregation_method=method, spatial_clustering=True, cluster_distance_threshold=100.0
+            aggregation_method=method,
+            spatial_clustering=True,
+            cluster_distance_threshold=100.0,
         )
 
         aggregated = aggregator.aggregate_results(mock_results, top_k=5)
@@ -1 +0,0 @@
-
@@ -72,7 +72,11 @@ def read_vector_raw(f, element_fmt_char):
 def read_numpy_vector(f, np_dtype, struct_fmt_char):
     """Reads a vector into a NumPy array."""
     count = -1  # Initialize count for robust error handling
-    print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end="", flush=True)
+    print(
+        f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
+        end="",
+        flush=True,
+    )
     try:
         count, data_bytes = read_vector_raw(f, struct_fmt_char)
         print(f"Count={count}, Bytes={len(data_bytes)}")
@@ -647,7 +651,10 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
         print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
         return False
     except MemoryError as e:
-        print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr)
+        print(
+            f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
+            file=sys.stderr,
+        )
         # Clean up potentially partially written output file?
         try:
             os.remove(output_filename)
@@ -245,7 +245,11 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
 
         # HF Hub's search is already fuzzy! It handles typos and partial matches
         models = list_models(
-            search=query, filter="text-generation", sort="downloads", direction=-1, limit=limit
+            search=query,
+            filter="text-generation",
+            sort="downloads",
+            direction=-1,
+            limit=limit,
        )
 
         model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
@@ -582,7 +586,11 @@ class HFChat(LLMInterface):
 
         # Tokenize input
         inputs = self.tokenizer(
-            formatted_prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048
+            formatted_prompt,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=2048,
         )
 
         # Move inputs to device
@@ -1,6 +1,6 @@
 import json
 import sqlite3
-import xml.etree.ElementTree as ET
+import xml.etree.ElementTree as ElementTree
 from pathlib import Path
 from typing import Annotated
 
@@ -26,7 +26,7 @@ def get_safe_path(s: str) -> str:
 def process_history(history: str):
     if history.startswith("<?xml") or history.startswith("<msg>"):
         try:
-            root = ET.fromstring(history)
+            root = ElementTree.fromstring(history)
             title = root.find(".//title").text if root.find(".//title") is not None else None
             quoted = (
                 root.find(".//refermsg/content").text
@@ -52,7 +52,8 @@ def get_message(history: dict | str):
 
 def export_chathistory(user_id: str):
     res = requests.get(
-        "http://localhost:48065/wechat/chatlog", params={"userId": user_id, "count": 100000}
+        "http://localhost:48065/wechat/chatlog",
+        params={"userId": user_id, "count": 100000},
     ).json()
     for i in range(len(res["chatLogs"])):
         res["chatLogs"][i]["content"] = process_history(res["chatLogs"][i]["content"])
@@ -116,7 +117,8 @@ def export_sqlite(
     all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
     for user in tqdm(all_users):
         cursor.execute(
-            "INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)", (user["arg"], user["title"])
+            "INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)",
+            (user["arg"], user["title"]),
         )
         usr_chatlog = export_chathistory(user["arg"])
         for msg in usr_chatlog:
@@ -58,7 +58,8 @@ class GraphWrapper:
             self.graph = torch.cuda.CUDAGraph()
             with torch.cuda.graph(self.graph):
                 self.static_output = self.model(
-                    input_ids=self.static_input, attention_mask=self.static_attention_mask
+                    input_ids=self.static_input,
+                    attention_mask=self.static_attention_mask,
                 )
             self.use_cuda_graph = True
         else:
@@ -82,7 +83,10 @@ class GraphWrapper:
     def _warmup(self, num_warmup: int = 3):
         with torch.no_grad():
             for _ in range(num_warmup):
-                self.model(input_ids=self.static_input, attention_mask=self.static_attention_mask)
+                self.model(
+                    input_ids=self.static_input,
+                    attention_mask=self.static_attention_mask,
+                )
 
     def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
         if self.use_cuda_graph:
@@ -261,7 +265,10 @@ class Benchmark:
                 # print size
                 print(f"in_features: {in_features}, out_features: {out_features}")
                 new_module = bnb.nn.Linear8bitLt(
-                    in_features, out_features, bias=bias, has_fp16_weights=False
+                    in_features,
+                    out_features,
+                    bias=bias,
+                    has_fp16_weights=False,
                 )
 
                 # Copy weights and bias
@@ -350,8 +357,6 @@ class Benchmark:
         # Try xformers if available (only on CUDA)
         if torch.cuda.is_available():
             try:
-                from xformers.ops import memory_efficient_attention  # noqa: F401
-
                 if hasattr(model, "enable_xformers_memory_efficient_attention"):
                     model.enable_xformers_memory_efficient_attention()
                     print("- Enabled xformers memory efficient attention")
@@ -427,7 +432,11 @@ class Benchmark:
             else "cpu"
         )
         return torch.randint(
-            0, 1000, (batch_size, self.config.seq_length), device=device, dtype=torch.long
+            0,
+            1000,
+            (batch_size, self.config.seq_length),
+            device=device,
+            dtype=torch.long,
         )
 
     def _run_inference(
@@ -115,7 +115,13 @@ def main():
     # --- Plotting ---
     print("\n--- Generating Plot ---")
     plt.figure(figsize=(10, 6))
-    plt.plot(BATCH_SIZES, results_torch, marker="o", linestyle="-", label=f"PyTorch ({device})")
+    plt.plot(
+        BATCH_SIZES,
+        results_torch,
+        marker="o",
+        linestyle="-",
+        label=f"PyTorch ({device})",
+    )
     plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
 
     plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
@@ -170,7 +170,11 @@ class Benchmark:
 
     def _create_random_batch(self, batch_size: int) -> torch.Tensor:
         return torch.randint(
-            0, 1000, (batch_size, self.config.seq_length), device=self.device, dtype=torch.long
+            0,
+            1000,
+            (batch_size, self.config.seq_length),
+            device=self.device,
+            dtype=torch.long,
         )
 
     def _run_inference(self, input_ids: torch.Tensor) -> float:
@@ -256,7 +260,11 @@ def run_mlx_benchmark():
     """Run MLX-specific benchmark"""
     if not MLX_AVAILABLE:
         print("MLX not available, skipping MLX benchmark")
-        return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": "MLX not available"}
+        return {
+            "max_throughput": 0.0,
+            "avg_throughput": 0.0,
+            "error": "MLX not available",
+        }
 
     config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
 
@@ -265,7 +273,11 @@ def run_mlx_benchmark():
     results = benchmark.run()
 
     if not results:
-        return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": "No valid results"}
+        return {
+            "max_throughput": 0.0,
+            "avg_throughput": 0.0,
+            "error": "No valid results",
+        }
 
     max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
     avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])