feat: snapshot

This commit is contained in:
Dr.Lt.Data
2023-11-04 17:17:55 +09:00
parent 2d6633ec6c
commit 7fbe34f8db
10 changed files with 2518 additions and 1742 deletions

View File

@@ -4,7 +4,8 @@ import folder_paths
import os
import sys
import threading
import subprocess
import datetime
import re
def handle_stream(stream, prefix):
@@ -56,7 +57,7 @@ sys.path.append('../..')
from torchvision.datasets.utils import download_url
# ensure .js
print("### Loading: ComfyUI-Manager (V0.36.1)")
print("### Loading: ComfyUI-Manager (V0.37)")
comfy_ui_required_revision = 1240
comfy_ui_revision = "Unknown"
@@ -283,6 +284,14 @@ def __win_check_git_pull(path):
process.wait()
def switch_to_default_branch(repo):
show_result = repo.git.remote("show", "origin")
matches = re.search(r"\s*HEAD branch:\s*(.*)", show_result)
if matches:
default_branch = matches.group(1)
repo.git.checkout(default_branch)
def git_repo_has_updates(path, do_fetch=False, do_update=False):
if do_fetch:
print(f"\x1b[2K\rFetching: {path}", end='')
@@ -299,9 +308,6 @@ def git_repo_has_updates(path, do_fetch=False, do_update=False):
# Fetch the latest commits from the remote repository
repo = git.Repo(path)
current_branch = repo.active_branch
branch_name = current_branch.name
remote_name = 'origin'
remote = repo.remote(name=remote_name)
@@ -312,8 +318,11 @@ def git_repo_has_updates(path, do_fetch=False, do_update=False):
remote.fetch()
if do_update:
if repo.head.is_detached:
switch_to_default_branch(repo)
try:
remote.pull(rebase=True)
remote.pull()
repo.git.submodule('update', '--init', '--recursive')
new_commit_hash = repo.head.commit.hexsha
@@ -326,7 +335,13 @@ def git_repo_has_updates(path, do_fetch=False, do_update=False):
except Exception as e:
print(f"\nUpdating failed: {path}\n{e}", file=sys.stderr)
if repo.head.is_detached:
return True
# Get commit hash of the remote branch
current_branch = repo.active_branch
branch_name = current_branch.name
remote_commit_hash = repo.refs[f'{remote_name}/{branch_name}'].object.hexsha
# Compare the commit hashes to determine if the local repository is behind the remote repository
@@ -352,11 +367,17 @@ def git_pull(path):
return __win_check_git_pull(path)
else:
repo = git.Repo(path)
print(f"path={path} / repo.is_dirty: {repo.is_dirty()}")
if repo.is_dirty():
repo.git.stash()
if repo.head.is_detached:
switch_to_default_branch(repo)
origin = repo.remote(name='origin')
origin.pull(rebase=True)
origin.pull()
repo.git.submodule('update', '--init', '--recursive')
repo.close()
@@ -569,6 +590,8 @@ async def fetch_updates(request):
@server.PromptServer.instance.routes.get("/customnode/update_all")
async def update_all(request):
try:
save_snapshot_with_postfix('autosave')
if request.rel_url.query["mode"] == "local":
uri = local_db_custom_node_list
else:
@@ -663,6 +686,125 @@ async def fetch_externalmodel_list(request):
return web.json_response(json_obj, content_type='application/json')
@server.PromptServer.instance.routes.get("/snapshot/getlist")
async def get_snapshot_list(request):
snapshots_directory = os.path.join(os.path.dirname(__file__), 'snapshots')
items = [f[:-5] for f in os.listdir(snapshots_directory) if f.endswith('.json')]
items.sort(reverse=True)
return web.json_response({'items': items}, content_type='application/json')
@server.PromptServer.instance.routes.get("/snapshot/remove")
async def remove_snapshot(request):
try:
target = request.rel_url.query["target"]
path = os.path.join(os.path.dirname(__file__), 'snapshots', f"{target}.json")
if os.path.exists(path):
os.remove(path)
return web.Response(status=200)
except:
return web.Response(status=400)
@server.PromptServer.instance.routes.get("/snapshot/restore")
async def remove_snapshot(request):
try:
target = request.rel_url.query["target"]
path = os.path.join(os.path.dirname(__file__), 'snapshots', f"{target}.json")
if os.path.exists(path):
if not os.path.exists(startup_script_path):
os.makedirs(startup_script_path)
target_path = os.path.join(startup_script_path, "restore-snapshot.json")
shutil.copy(path, target_path)
print(f"Snapshot restore scheduled: `{target}`")
return web.Response(status=200)
print(f"Snapshot file not found: `{path}`")
return web.Response(status=400)
except:
return web.Response(status=400)
def get_current_snapshot():
# Get ComfyUI hash
repo_path = os.path.dirname(folder_paths.__file__)
if not os.path.exists(os.path.join(repo_path, '.git')):
print(f"ComfyUI update fail: The installed ComfyUI does not have a Git repository.")
return web.Response(status=400)
repo = git.Repo(repo_path)
comfyui_commit_hash = repo.head.commit.hexsha
git_custom_nodes = {}
file_custom_nodes = []
# Get custom nodes hash
for path in os.listdir(custom_nodes_path):
fullpath = os.path.join(custom_nodes_path, path)
if os.path.isdir(fullpath):
is_disabled = path.endswith(".disabled")
try:
git_dir = os.path.join(fullpath, '.git')
if not os.path.exists(git_dir):
continue
repo = git.Repo(fullpath)
commit_hash = repo.head.commit.hexsha
url = repo.remotes.origin.url
git_custom_nodes[url] = {
'hash': commit_hash,
'disabled': is_disabled
}
except:
print(f"Failed to extract snapshots for the custom node '{path}'.")
elif path.endswith('.py'):
is_disabled = path.endswith(".py.disabled")
filename = os.path.basename(path)
item = {
'filename': filename,
'disabled': is_disabled
}
file_custom_nodes.append(item)
return {
'comfyui': comfyui_commit_hash,
'git_custom_nodes': git_custom_nodes,
'file_custom_nodes': file_custom_nodes,
}
def save_snapshot_with_postfix(postfix):
now = datetime.datetime.now()
date_time_format = now.strftime("%Y-%m-%d_%H-%M-%S")
file_name = f"{date_time_format}_{postfix}"
path = os.path.join(os.path.dirname(__file__), 'snapshots', f"{file_name}.json")
with open(path, "w") as json_file:
json.dump(get_current_snapshot(), json_file, indent=4)
@server.PromptServer.instance.routes.get("/snapshot/save")
async def save_snapshot(request):
try:
save_snapshot_with_postfix('snapshot')
return web.Response(status=200)
except:
return web.Response(status=400)
def unzip_install(files):
temp_filename = 'manager-temp.zip'
for url in files:
@@ -1073,6 +1215,9 @@ async def update_comfyui(request):
# version check
repo = git.Repo(repo_path)
if repo.head.is_detached:
switch_to_default_branch(repo)
current_branch = repo.active_branch
branch_name = current_branch.name
@@ -1202,6 +1347,8 @@ async def channel_url_list(request):
return web.Response(status=200)
WEB_DIRECTORY = "js"
NODE_CLASS_MAPPINGS = {}
__all__ = ['NODE_CLASS_MAPPINGS']