merged updates from upstream, and following output types work: SaveImage, PreviewImage, AnimateDiffCombine, VideoCombine
This commit is contained in:
228
__init__.py
228
__init__.py
@@ -5,12 +5,31 @@ import folder_paths
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import subprocess
|
||||
import datetime
|
||||
import re
|
||||
import locale
|
||||
import subprocess # don't remove this
|
||||
from tqdm.auto import tqdm
|
||||
import concurrent
|
||||
import ssl
|
||||
|
||||
version = "V0.41"
|
||||
print(f"### Loading: ComfyUI-Manager ({version})")
|
||||
|
||||
|
||||
def handle_stream(stream, prefix):
|
||||
for line in stream:
|
||||
print(prefix, line, end="")
|
||||
stream.reconfigure(encoding=locale.getpreferredencoding(), errors='replace')
|
||||
for msg in stream:
|
||||
if prefix == '[!]' and ('it/s]' in msg or 's/it]' in msg) and ('%|' in msg or 'it [' in msg):
|
||||
if msg.startswith('100%'):
|
||||
print('\r' + msg, end="", file=sys.stderr),
|
||||
else:
|
||||
print('\r' + msg[:-1], end="", file=sys.stderr),
|
||||
else:
|
||||
if prefix == '[!]':
|
||||
print(prefix, msg, end="", file=sys.stderr)
|
||||
else:
|
||||
print(prefix, msg, end="")
|
||||
|
||||
|
||||
def run_script(cmd, cwd='.'):
|
||||
@@ -52,13 +71,12 @@ except:
|
||||
print(f"## ComfyUI-Manager: installing dependencies done.")
|
||||
|
||||
|
||||
from git.remote import RemoteProgress
|
||||
|
||||
sys.path.append('../..')
|
||||
|
||||
from torchvision.datasets.utils import download_url
|
||||
|
||||
# ensure .js
|
||||
print("### Loading: ComfyUI-Manager (V0.36)")
|
||||
|
||||
comfy_ui_required_revision = 1240
|
||||
comfy_ui_revision = "Unknown"
|
||||
|
||||
@@ -95,7 +113,8 @@ def write_config():
|
||||
'badge_mode': get_config()['badge_mode'],
|
||||
'git_exe': get_config()['git_exe'],
|
||||
'channel_url': get_config()['channel_url'],
|
||||
'channel_url_list': get_config()['channel_url_list']
|
||||
'channel_url_list': get_config()['channel_url_list'],
|
||||
'bypass_ssl': get_config()['bypass_ssl']
|
||||
}
|
||||
with open(config_path, 'w') as configfile:
|
||||
config.write(configfile)
|
||||
@@ -125,7 +144,8 @@ def read_config():
|
||||
'badge_mode': default_conf['badge_mode'] if 'badge_mode' in default_conf else 'none',
|
||||
'git_exe': default_conf['git_exe'] if 'git_exe' in default_conf else '',
|
||||
'channel_url': default_conf['channel_url'] if 'channel_url' in default_conf else 'https://raw.githubusercontent.com/ltdrdata/ComfyUI-Manager/main',
|
||||
'channel_url_list': ch_url_list
|
||||
'channel_url_list': ch_url_list,
|
||||
'bypass_ssl': default_conf['bypass_ssl'] if 'bypass_ssl' in default_conf else False,
|
||||
}
|
||||
|
||||
except Exception:
|
||||
@@ -134,7 +154,8 @@ def read_config():
|
||||
'badge_mode': 'none',
|
||||
'git_exe': '',
|
||||
'channel_url': 'https://raw.githubusercontent.com/ltdrdata/ComfyUI-Manager/main',
|
||||
'channel_url_list': ''
|
||||
'channel_url_list': '',
|
||||
'bypass_ssl': False
|
||||
}
|
||||
|
||||
|
||||
@@ -284,6 +305,14 @@ def __win_check_git_pull(path):
|
||||
process.wait()
|
||||
|
||||
|
||||
def switch_to_default_branch(repo):
|
||||
show_result = repo.git.remote("show", "origin")
|
||||
matches = re.search(r"\s*HEAD branch:\s*(.*)", show_result)
|
||||
if matches:
|
||||
default_branch = matches.group(1)
|
||||
repo.git.checkout(default_branch)
|
||||
|
||||
|
||||
def git_repo_has_updates(path, do_fetch=False, do_update=False):
|
||||
if do_fetch:
|
||||
print(f"\x1b[2K\rFetching: {path}", end='')
|
||||
@@ -300,9 +329,6 @@ def git_repo_has_updates(path, do_fetch=False, do_update=False):
|
||||
# Fetch the latest commits from the remote repository
|
||||
repo = git.Repo(path)
|
||||
|
||||
current_branch = repo.active_branch
|
||||
branch_name = current_branch.name
|
||||
|
||||
remote_name = 'origin'
|
||||
remote = repo.remote(name=remote_name)
|
||||
|
||||
@@ -313,8 +339,11 @@ def git_repo_has_updates(path, do_fetch=False, do_update=False):
|
||||
remote.fetch()
|
||||
|
||||
if do_update:
|
||||
if repo.head.is_detached:
|
||||
switch_to_default_branch(repo)
|
||||
|
||||
try:
|
||||
remote.pull(rebase=True)
|
||||
remote.pull()
|
||||
repo.git.submodule('update', '--init', '--recursive')
|
||||
new_commit_hash = repo.head.commit.hexsha
|
||||
|
||||
@@ -327,7 +356,13 @@ def git_repo_has_updates(path, do_fetch=False, do_update=False):
|
||||
except Exception as e:
|
||||
print(f"\nUpdating failed: {path}\n{e}", file=sys.stderr)
|
||||
|
||||
if repo.head.is_detached:
|
||||
return True
|
||||
|
||||
# Get commit hash of the remote branch
|
||||
current_branch = repo.active_branch
|
||||
branch_name = current_branch.name
|
||||
|
||||
remote_commit_hash = repo.refs[f'{remote_name}/{branch_name}'].object.hexsha
|
||||
|
||||
# Compare the commit hashes to determine if the local repository is behind the remote repository
|
||||
@@ -353,11 +388,17 @@ def git_pull(path):
|
||||
return __win_check_git_pull(path)
|
||||
else:
|
||||
repo = git.Repo(path)
|
||||
|
||||
print(f"path={path} / repo.is_dirty: {repo.is_dirty()}")
|
||||
|
||||
if repo.is_dirty():
|
||||
repo.git.stash()
|
||||
|
||||
if repo.head.is_detached:
|
||||
switch_to_default_branch(repo)
|
||||
|
||||
origin = repo.remote(name='origin')
|
||||
origin.pull(rebase=True)
|
||||
origin.pull()
|
||||
repo.git.submodule('update', '--init', '--recursive')
|
||||
|
||||
repo.close()
|
||||
@@ -518,9 +559,13 @@ def check_custom_nodes_installed(json_obj, do_fetch=False, do_update_check=True,
|
||||
elif do_update_check:
|
||||
print("Start update check...", end="")
|
||||
|
||||
for item in json_obj['custom_nodes']:
|
||||
def process_custom_node(item):
|
||||
check_a_custom_node_installed(item, do_fetch, do_update_check, do_update)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(4) as executor:
|
||||
for item in json_obj['custom_nodes']:
|
||||
executor.submit(process_custom_node, item)
|
||||
|
||||
if do_fetch:
|
||||
print(f"\x1b[2K\rFetching done.")
|
||||
elif do_update:
|
||||
@@ -569,6 +614,8 @@ async def fetch_updates(request):
|
||||
@server.PromptServer.instance.routes.get("/customnode/update_all")
|
||||
async def update_all(request):
|
||||
try:
|
||||
save_snapshot_with_postfix('autosave')
|
||||
|
||||
if request.rel_url.query["mode"] == "local":
|
||||
uri = local_db_custom_node_list
|
||||
else:
|
||||
@@ -638,10 +685,9 @@ async def fetch_alternatives_list(request):
|
||||
|
||||
|
||||
def check_model_installed(json_obj):
|
||||
for item in json_obj['models']:
|
||||
item['installed'] = 'None'
|
||||
|
||||
def process_model(item):
|
||||
model_path = get_model_path(item)
|
||||
item['installed'] = 'None'
|
||||
|
||||
if model_path is not None:
|
||||
if os.path.exists(model_path):
|
||||
@@ -649,6 +695,10 @@ def check_model_installed(json_obj):
|
||||
else:
|
||||
item['installed'] = 'False'
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(8) as executor:
|
||||
for item in json_obj['models']:
|
||||
executor.submit(process_model, item)
|
||||
|
||||
|
||||
@server.PromptServer.instance.routes.get("/externalmodel/getlist")
|
||||
async def fetch_externalmodel_list(request):
|
||||
@@ -663,6 +713,125 @@ async def fetch_externalmodel_list(request):
|
||||
return web.json_response(json_obj, content_type='application/json')
|
||||
|
||||
|
||||
@server.PromptServer.instance.routes.get("/snapshot/getlist")
|
||||
async def get_snapshot_list(request):
|
||||
snapshots_directory = os.path.join(os.path.dirname(__file__), 'snapshots')
|
||||
items = [f[:-5] for f in os.listdir(snapshots_directory) if f.endswith('.json')]
|
||||
items.sort(reverse=True)
|
||||
return web.json_response({'items': items}, content_type='application/json')
|
||||
|
||||
|
||||
@server.PromptServer.instance.routes.get("/snapshot/remove")
|
||||
async def remove_snapshot(request):
|
||||
try:
|
||||
target = request.rel_url.query["target"]
|
||||
|
||||
path = os.path.join(os.path.dirname(__file__), 'snapshots', f"{target}.json")
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
|
||||
return web.Response(status=200)
|
||||
except:
|
||||
return web.Response(status=400)
|
||||
|
||||
|
||||
@server.PromptServer.instance.routes.get("/snapshot/restore")
|
||||
async def remove_snapshot(request):
|
||||
try:
|
||||
target = request.rel_url.query["target"]
|
||||
|
||||
path = os.path.join(os.path.dirname(__file__), 'snapshots', f"{target}.json")
|
||||
if os.path.exists(path):
|
||||
if not os.path.exists(startup_script_path):
|
||||
os.makedirs(startup_script_path)
|
||||
|
||||
target_path = os.path.join(startup_script_path, "restore-snapshot.json")
|
||||
shutil.copy(path, target_path)
|
||||
|
||||
print(f"Snapshot restore scheduled: `{target}`")
|
||||
return web.Response(status=200)
|
||||
|
||||
print(f"Snapshot file not found: `{path}`")
|
||||
return web.Response(status=400)
|
||||
except:
|
||||
return web.Response(status=400)
|
||||
|
||||
|
||||
def get_current_snapshot():
|
||||
# Get ComfyUI hash
|
||||
repo_path = os.path.dirname(folder_paths.__file__)
|
||||
|
||||
if not os.path.exists(os.path.join(repo_path, '.git')):
|
||||
print(f"ComfyUI update fail: The installed ComfyUI does not have a Git repository.")
|
||||
return web.Response(status=400)
|
||||
|
||||
repo = git.Repo(repo_path)
|
||||
comfyui_commit_hash = repo.head.commit.hexsha
|
||||
|
||||
git_custom_nodes = {}
|
||||
file_custom_nodes = []
|
||||
|
||||
# Get custom nodes hash
|
||||
for path in os.listdir(custom_nodes_path):
|
||||
fullpath = os.path.join(custom_nodes_path, path)
|
||||
|
||||
if os.path.isdir(fullpath):
|
||||
is_disabled = path.endswith(".disabled")
|
||||
|
||||
try:
|
||||
git_dir = os.path.join(fullpath, '.git')
|
||||
|
||||
if not os.path.exists(git_dir):
|
||||
continue
|
||||
|
||||
repo = git.Repo(fullpath)
|
||||
commit_hash = repo.head.commit.hexsha
|
||||
url = repo.remotes.origin.url
|
||||
git_custom_nodes[url] = {
|
||||
'hash': commit_hash,
|
||||
'disabled': is_disabled
|
||||
}
|
||||
|
||||
except:
|
||||
print(f"Failed to extract snapshots for the custom node '{path}'.")
|
||||
|
||||
elif path.endswith('.py'):
|
||||
is_disabled = path.endswith(".py.disabled")
|
||||
filename = os.path.basename(path)
|
||||
item = {
|
||||
'filename': filename,
|
||||
'disabled': is_disabled
|
||||
}
|
||||
|
||||
file_custom_nodes.append(item)
|
||||
|
||||
return {
|
||||
'comfyui': comfyui_commit_hash,
|
||||
'git_custom_nodes': git_custom_nodes,
|
||||
'file_custom_nodes': file_custom_nodes,
|
||||
}
|
||||
|
||||
|
||||
def save_snapshot_with_postfix(postfix):
|
||||
now = datetime.datetime.now()
|
||||
|
||||
date_time_format = now.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
file_name = f"{date_time_format}_{postfix}"
|
||||
|
||||
path = os.path.join(os.path.dirname(__file__), 'snapshots', f"{file_name}.json")
|
||||
with open(path, "w") as json_file:
|
||||
json.dump(get_current_snapshot(), json_file, indent=4)
|
||||
|
||||
|
||||
@server.PromptServer.instance.routes.get("/snapshot/save")
|
||||
async def save_snapshot(request):
|
||||
try:
|
||||
save_snapshot_with_postfix('snapshot')
|
||||
return web.Response(status=200)
|
||||
except:
|
||||
return web.Response(status=400)
|
||||
|
||||
|
||||
def unzip_install(files):
|
||||
temp_filename = 'manager-temp.zip'
|
||||
for url in files:
|
||||
@@ -810,6 +979,18 @@ def execute_install_script(url, repo_path):
|
||||
return True
|
||||
|
||||
|
||||
class GitProgress(RemoteProgress):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pbar = tqdm()
|
||||
|
||||
def update(self, op_code, cur_count, max_count=None, message=''):
|
||||
self.pbar.total = max_count
|
||||
self.pbar.n = cur_count
|
||||
self.pbar.pos = 0
|
||||
self.pbar.refresh()
|
||||
|
||||
|
||||
def gitclone_install(files):
|
||||
print(f"install: {files}")
|
||||
for url in files:
|
||||
@@ -824,7 +1005,7 @@ def gitclone_install(files):
|
||||
if platform.system() == 'Windows':
|
||||
run_script([sys.executable, git_script_path, "--clone", custom_nodes_path, url])
|
||||
else:
|
||||
repo = git.Repo.clone_from(url, repo_path, recursive=True)
|
||||
repo = git.Repo.clone_from(url, repo_path, recursive=True, progress=GitProgress())
|
||||
repo.git.clear_cache()
|
||||
repo.close()
|
||||
|
||||
@@ -1073,6 +1254,9 @@ async def update_comfyui(request):
|
||||
# version check
|
||||
repo = git.Repo(repo_path)
|
||||
|
||||
if repo.head.is_detached:
|
||||
switch_to_default_branch(repo)
|
||||
|
||||
current_branch = repo.active_branch
|
||||
branch_name = current_branch.name
|
||||
|
||||
@@ -1133,6 +1317,7 @@ async def install_model(request):
|
||||
if json_data['url'].startswith('https://github.com') or json_data['url'].startswith('https://huggingface.co'):
|
||||
model_dir = get_model_dir(json_data)
|
||||
download_url(json_data['url'], model_dir)
|
||||
|
||||
return web.json_response({}, content_type='application/json')
|
||||
else:
|
||||
res = download_url_with_agent(json_data['url'], model_path)
|
||||
@@ -1419,6 +1604,11 @@ async def share_art(request):
|
||||
}
|
||||
}, content_type='application/json', status=200)
|
||||
|
||||
if get_config()['bypass_ssl']:
|
||||
ssl._create_default_https_context = ssl._create_unverified_context # SSL certificate error fix.
|
||||
|
||||
|
||||
WEB_DIRECTORY = "js"
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
__all__ = ['NODE_CLASS_MAPPINGS']
|
||||
|
||||
|
||||
Reference in New Issue
Block a user