1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Add model downloading endpoint. (#4248)

* Add model downloading endpoint.

* Move client session init to async function.

* Break up large function.

* Send "download_progress" as websocket event.

* Fixed

* Fixed.

* Use async mock.

* Move server set up to right before run call.

* Validate that model subdirectory cannot contain relative paths.

* Add download_model test checking for invalid paths.

* Remove DS_Store.

* Consolidate DownloadStatus and DownloadModelResult

* Add progress_interval as an optional parameter.

* Use tuple type from annotations.

* Use pydantic.

* Update comment.

* Revert "Use pydantic."

This reverts commit 7461e8eb00.

* Add new line.

* Add newline EOF.

* Validate model filename as well.

* Add comment to not reply on internal.

* Restrict downloading to safetensor files only.
This commit is contained in:
Robin Huang
2024-08-13 12:48:52 -07:00
committed by GitHub
parent 34608de2e9
commit 3e52e0364c
7 changed files with 599 additions and 2 deletions

View File

@@ -12,7 +12,6 @@ import json
import glob
import struct
import ssl
import hashlib
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO
@@ -28,7 +27,8 @@ import comfy.model_management
import node_helpers
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from model_filemanager import download_model, DownloadModelStatus
from typing import Optional
class BinaryEventTypes:
PREVIEW_IMAGE = 1
@@ -76,6 +76,7 @@ class PromptServer():
self.prompt_queue = None
self.loop = loop
self.messages = asyncio.Queue()
self.client_session:Optional[aiohttp.ClientSession] = None
self.number = 0
middlewares = [cache_control]
@@ -559,6 +560,36 @@ class PromptServer():
self.prompt_queue.delete_history_item(id_to_delete)
return web.Response(status=200)
# Internal route. Should not be depended upon and is subject to change at any time.
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
@routes.post("/internal/models/download")
async def download_handler(request):
async def report_progress(filename: str, status: DownloadModelStatus):
await self.send_json("download_progress", status.to_dict())
data = await request.json()
url = data.get('url')
model_directory = data.get('model_directory')
model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
if not url or not model_directory or not model_filename:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
session = self.client_session
if session is None:
logging.error("Client session is not initialized")
return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
await task
return web.json_response(task.result().to_dict())
async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)
def add_routes(self):
self.user_manager.add_routes(self.routes)