feat: snapshot
This commit is contained in:
196
git_helper.py
196
git_helper.py
@@ -2,23 +2,39 @@ import sys
|
||||
import os
|
||||
import git
|
||||
import configparser
|
||||
import re
|
||||
import json
|
||||
from torchvision.datasets.utils import download_url
|
||||
|
||||
config_path = os.path.join(os.path.dirname(__file__), "config.ini")
|
||||
nodelist_path = os.path.join(os.path.dirname(__file__), "custom-node-list.json")
|
||||
working_directory = os.getcwd()
|
||||
|
||||
def gitclone(custom_nodes_path, url):
|
||||
|
||||
def gitclone(custom_nodes_path, url, target_hash=None):
|
||||
repo_name = os.path.splitext(os.path.basename(url))[0]
|
||||
repo_path = os.path.join(custom_nodes_path, repo_name)
|
||||
|
||||
# Clone the repository from the remote URL
|
||||
repo = git.Repo.clone_from(url, repo_path, recursive=True)
|
||||
|
||||
if target_hash is not None:
|
||||
print(f"CHECKOUT: {repo_name} [{target_hash}]")
|
||||
repo.git.checkout(target_hash)
|
||||
|
||||
repo.git.clear_cache()
|
||||
repo.close()
|
||||
|
||||
|
||||
def gitcheck(path, do_fetch=False):
|
||||
try:
|
||||
# Fetch the latest commits from the remote repository
|
||||
repo = git.Repo(path)
|
||||
|
||||
if repo.head.is_detached:
|
||||
print("CUSTOM NODE CHECK: True")
|
||||
return
|
||||
|
||||
current_branch = repo.active_branch
|
||||
branch_name = current_branch.name
|
||||
|
||||
@@ -48,6 +64,14 @@ def gitcheck(path, do_fetch=False):
|
||||
print("CUSTOM NODE CHECK: Error")
|
||||
|
||||
|
||||
def switch_to_default_branch(repo):
|
||||
show_result = repo.git.remote("show", "origin")
|
||||
matches = re.search(r"\s*HEAD branch:\s*(.*)", show_result)
|
||||
if matches:
|
||||
default_branch = matches.group(1)
|
||||
repo.git.checkout(default_branch)
|
||||
|
||||
|
||||
def gitpull(path):
|
||||
# Check if the path is a git repository
|
||||
if not os.path.exists(os.path.join(path, '.git')):
|
||||
@@ -60,8 +84,12 @@ def gitpull(path):
|
||||
|
||||
commit_hash = repo.head.commit.hexsha
|
||||
try:
|
||||
if repo.head.is_detached:
|
||||
switch_to_default_branch(repo)
|
||||
|
||||
origin = repo.remote(name='origin')
|
||||
origin.pull(rebase=True)
|
||||
origin.pull()
|
||||
|
||||
repo.git.submodule('update', '--init', '--recursive')
|
||||
new_commit_hash = repo.head.commit.hexsha
|
||||
|
||||
@@ -76,6 +104,165 @@ def gitpull(path):
|
||||
repo.close()
|
||||
|
||||
|
||||
def checkout_comfyui_hash(target_hash):
|
||||
repo_path = os.path.join(working_directory, '..') # ComfyUI dir
|
||||
|
||||
repo = git.Repo(repo_path)
|
||||
commit_hash = repo.head.commit.hexsha
|
||||
|
||||
if commit_hash != target_hash:
|
||||
try:
|
||||
print(f"CHECKOUT: ComfyUI [{target_hash}]")
|
||||
repo.git.checkout(target_hash)
|
||||
except git.GitCommandError as e:
|
||||
print(f"Error checking out the ComfyUI: {str(e)}")
|
||||
|
||||
|
||||
def checkout_custom_node_hash(git_custom_node_infos):
|
||||
repo_name_to_url = {}
|
||||
|
||||
for url in git_custom_node_infos.keys():
|
||||
repo_name = url.split('/')[-1]
|
||||
|
||||
if repo_name.endswith('.git'):
|
||||
repo_name = repo_name[:-4]
|
||||
|
||||
repo_name_to_url[repo_name] = url
|
||||
|
||||
for path in os.listdir(working_directory):
|
||||
if path.endswith("ComfyUI-Manager"):
|
||||
continue
|
||||
|
||||
fullpath = os.path.join(working_directory, path)
|
||||
|
||||
if os.path.isdir(fullpath):
|
||||
is_disabled = path.endswith(".disabled")
|
||||
|
||||
try:
|
||||
git_dir = os.path.join(fullpath, '.git')
|
||||
if not os.path.exists(git_dir):
|
||||
continue
|
||||
|
||||
need_checkout = False
|
||||
repo_name = os.path.basename(fullpath)
|
||||
|
||||
if repo_name.endswith('.disabled'):
|
||||
repo_name = repo_name[:-9]
|
||||
|
||||
item = git_custom_node_infos[repo_name_to_url[repo_name]]
|
||||
if item['disabled'] and is_disabled:
|
||||
pass
|
||||
elif item['disabled'] and not is_disabled:
|
||||
# disable
|
||||
print(f"DISABLE: {repo_name}")
|
||||
new_path = fullpath + ".disabled"
|
||||
os.rename(fullpath, new_path)
|
||||
pass
|
||||
elif not item['disabled'] and is_disabled:
|
||||
# enable
|
||||
print(f"ENABLE: {repo_name}")
|
||||
new_path = fullpath[:-9]
|
||||
os.rename(fullpath, new_path)
|
||||
fullpath = new_path
|
||||
need_checkout = True
|
||||
else:
|
||||
need_checkout = True
|
||||
|
||||
if need_checkout:
|
||||
repo = git.Repo(fullpath)
|
||||
commit_hash = repo.head.commit.hexsha
|
||||
|
||||
if commit_hash != item['hash']:
|
||||
print(f"CHECKOUT: {repo_name} [{item['hash']}]")
|
||||
repo.git.checkout(item['hash'])
|
||||
except Exception:
|
||||
print(f"Failed to restore snapshots for the custom node '{path}'")
|
||||
|
||||
# clone missing
|
||||
for k, v in git_custom_node_infos.items():
|
||||
if not v['disabled']:
|
||||
repo_name = k.split('/')[-1]
|
||||
if repo_name.endswith('.git'):
|
||||
repo_name = repo_name[:-4]
|
||||
|
||||
path = os.path.join(working_directory, repo_name)
|
||||
if not os.path.exists(path):
|
||||
print(f"CLONE: {path}")
|
||||
gitclone(working_directory, k, v['hash'])
|
||||
|
||||
|
||||
def invalidate_custom_node_file(file_custom_node_infos):
|
||||
global nodelist_path
|
||||
|
||||
enabled_set = set()
|
||||
for item in file_custom_node_infos:
|
||||
if not item['disabled']:
|
||||
enabled_set.add(item['filename'])
|
||||
|
||||
for path in os.listdir(working_directory):
|
||||
fullpath = os.path.join(working_directory, path)
|
||||
|
||||
if not os.path.isdir(fullpath) and fullpath.endswith('.py'):
|
||||
if path not in enabled_set:
|
||||
print(f"DISABLE: {path}")
|
||||
new_path = fullpath+'.disabled'
|
||||
os.rename(fullpath, new_path)
|
||||
|
||||
elif not os.path.isdir(fullpath) and fullpath.endswith('.py.disabled'):
|
||||
path = path[:-9]
|
||||
if path in enabled_set:
|
||||
print(f"ENABLE: {path}")
|
||||
new_path = fullpath[:-9]
|
||||
os.rename(fullpath, new_path)
|
||||
|
||||
# download missing: just support for 'copy' style
|
||||
py_to_url = {}
|
||||
|
||||
with open(nodelist_path, 'r', encoding="UTF-8") as json_file:
|
||||
info = json.load(json_file)
|
||||
for item in info['custom_nodes']:
|
||||
if item['install_type'] == 'copy':
|
||||
for url in item['files']:
|
||||
if url.endswith('.py'):
|
||||
py = url.split('/')[-1]
|
||||
py_to_url[py] = url
|
||||
|
||||
for item in file_custom_node_infos:
|
||||
filename = item['filename']
|
||||
if not item['disabled']:
|
||||
target_path = os.path.join(working_directory, filename)
|
||||
|
||||
if not os.path.exists(target_path) and filename in py_to_url:
|
||||
url = py_to_url[filename]
|
||||
print(f"DOWNLOAD: {filename}")
|
||||
download_url(url, working_directory)
|
||||
|
||||
|
||||
def apply_snapshot(target):
|
||||
try:
|
||||
path = os.path.join(os.path.dirname(__file__), 'snapshots', f"{target}")
|
||||
if os.path.exists(path):
|
||||
with open(path, 'r', encoding="UTF-8") as json_file:
|
||||
info = json.load(json_file)
|
||||
|
||||
comfyui_hash = info['comfyui']
|
||||
git_custom_node_infos = info['git_custom_nodes']
|
||||
file_custom_node_infos = info['file_custom_nodes']
|
||||
|
||||
checkout_comfyui_hash(comfyui_hash)
|
||||
checkout_custom_node_hash(git_custom_node_infos)
|
||||
invalidate_custom_node_file(file_custom_node_infos)
|
||||
|
||||
print("APPLY SNAPSHOT: True")
|
||||
return
|
||||
|
||||
print(f"Snapshot file not found: `{path}`")
|
||||
print("APPLY SNAPSHOT: False")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("APPLY SNAPSHOT: False")
|
||||
|
||||
|
||||
def setup_environment():
|
||||
config = configparser.ConfigParser()
|
||||
config.read(config_path)
|
||||
@@ -95,8 +282,11 @@ try:
|
||||
gitcheck(sys.argv[2], True)
|
||||
elif sys.argv[1] == "--pull":
|
||||
gitpull(sys.argv[2])
|
||||
elif sys.argv[1] == "--apply-snapshot":
|
||||
apply_snapshot(sys.argv[2])
|
||||
sys.exit(0)
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user