fix: format colqwen_forward.py to pass pre-commit checks
This commit is contained in:
@@ -8,10 +8,9 @@ from pathlib import Path
|
|||||||
# Add the current directory to path to import leann_multi_vector
|
# Add the current directory to path to import leann_multi_vector
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
import torch
|
||||||
|
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
|
||||||
from leann_multi_vector import _load_colvision, _embed_images, _ensure_repo_paths_importable
|
from PIL import Image
|
||||||
|
|
||||||
# Ensure repo paths are importable
|
# Ensure repo paths are importable
|
||||||
_ensure_repo_paths_importable(__file__)
|
_ensure_repo_paths_importable(__file__)
|
||||||
@@ -23,7 +22,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|||||||
def create_test_image():
|
def create_test_image():
|
||||||
"""Create a simple test image."""
|
"""Create a simple test image."""
|
||||||
# Create a simple RGB image (800x600)
|
# Create a simple RGB image (800x600)
|
||||||
img = Image.new('RGB', (800, 600), color='white')
|
img = Image.new("RGB", (800, 600), color="white")
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
@@ -42,8 +41,8 @@ def load_test_image_from_file():
|
|||||||
for img_dir in possible_paths:
|
for img_dir in possible_paths:
|
||||||
if img_dir.exists():
|
if img_dir.exists():
|
||||||
# Find first image file
|
# Find first image file
|
||||||
for ext in ['.png', '.jpg', '.jpeg']:
|
for ext in [".png", ".jpg", ".jpeg"]:
|
||||||
for img_file in img_dir.glob(f'*{ext}'):
|
for img_file in img_dir.glob(f"*{ext}"):
|
||||||
print(f"Loading test image from: {img_file}")
|
print(f"Loading test image from: {img_file}")
|
||||||
return Image.open(img_file)
|
return Image.open(img_file)
|
||||||
|
|
||||||
@@ -65,8 +64,8 @@ def main():
|
|||||||
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
|
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
|
||||||
|
|
||||||
# Convert to RGB if needed
|
# Convert to RGB if needed
|
||||||
if test_image.mode != 'RGB':
|
if test_image.mode != "RGB":
|
||||||
test_image = test_image.convert('RGB')
|
test_image = test_image.convert("RGB")
|
||||||
print(f"✓ Converted to RGB: {test_image.size}")
|
print(f"✓ Converted to RGB: {test_image.size}")
|
||||||
|
|
||||||
# Step 2: Load model
|
# Step 2: Load model
|
||||||
@@ -77,14 +76,15 @@ def main():
|
|||||||
print(f"✓ Device: {device_str}, dtype: {dtype}")
|
print(f"✓ Device: {device_str}, dtype: {dtype}")
|
||||||
|
|
||||||
# Print model info
|
# Print model info
|
||||||
if hasattr(model, 'device'):
|
if hasattr(model, "device"):
|
||||||
print(f"✓ Model device: {model.device}")
|
print(f"✓ Model device: {model.device}")
|
||||||
if hasattr(model, 'dtype'):
|
if hasattr(model, "dtype"):
|
||||||
print(f"✓ Model dtype: {model.dtype}")
|
print(f"✓ Model dtype: {model.dtype}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Error loading model: {e}")
|
print(f"✗ Error loading model: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -97,14 +97,14 @@ def main():
|
|||||||
|
|
||||||
doc_vecs = _embed_images(model, processor, images)
|
doc_vecs = _embed_images(model, processor, images)
|
||||||
|
|
||||||
print(f"✓ Forward pass completed!")
|
print("✓ Forward pass completed!")
|
||||||
print(f"✓ Number of embeddings: {len(doc_vecs)}")
|
print(f"✓ Number of embeddings: {len(doc_vecs)}")
|
||||||
|
|
||||||
if len(doc_vecs) > 0:
|
if len(doc_vecs) > 0:
|
||||||
emb = doc_vecs[0]
|
emb = doc_vecs[0]
|
||||||
print(f"✓ Embedding shape: {emb.shape}")
|
print(f"✓ Embedding shape: {emb.shape}")
|
||||||
print(f"✓ Embedding dtype: {emb.dtype}")
|
print(f"✓ Embedding dtype: {emb.dtype}")
|
||||||
print(f"✓ Embedding stats:")
|
print("✓ Embedding stats:")
|
||||||
print(f" - Min: {emb.min().item():.4f}")
|
print(f" - Min: {emb.min().item():.4f}")
|
||||||
print(f" - Max: {emb.max().item():.4f}")
|
print(f" - Max: {emb.max().item():.4f}")
|
||||||
print(f" - Mean: {emb.mean().item():.4f}")
|
print(f" - Mean: {emb.mean().item():.4f}")
|
||||||
@@ -119,6 +119,7 @@ def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Error during forward pass: {e}")
|
print(f"✗ Error during forward pass: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -129,4 +130,3 @@ def main():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user