1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Give linux some love

This commit is contained in:
reaper47
2023-06-07 15:15:38 +02:00
parent ee62b4ecc2
commit 5cf4079923
2 changed files with 37 additions and 34 deletions

20
main.py
View File

@@ -37,21 +37,25 @@ def prompt_worker(q, server):
e.execute(item[2], item[1], item[3], item[4])
q.task_done(item_id, e.outputs_ui)
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
def hijack_progress(server):
def hook(value, total, preview_image_bytes):
server.send_sync("progress", { "value": value, "max": total}, server.client_id)
server.send_sync("progress", {"value": value, "max": total}, server.client_id)
if preview_image_bytes is not None:
server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
comfy.utils.set_progress_bar_global_hook(hook)
def cleanup_temp():
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
def load_extra_path_config(yaml_path):
with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream)
@@ -72,6 +76,7 @@ def load_extra_path_config(yaml_path):
print("Adding extra search path", x, full_path)
folder_paths.add_model_folder_path(x, full_path)
if __name__ == "__main__":
cleanup_temp()
@@ -92,7 +97,7 @@ if __name__ == "__main__":
server.add_routes()
hijack_progress(server)
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
@@ -106,15 +111,12 @@ if __name__ == "__main__":
if args.auto_launch:
def startup_server(address, port):
import webbrowser
webbrowser.open("http://{}:{}".format(address, port))
webbrowser.open(f"http://{address}:{port}")
call_on_start = startup_server
if os.name == "nt":
try:
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt:
pass
else:
try:
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt:
print("\nStopped server")
cleanup_temp()