1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Support preview images embedded in safetensors metadata (#6119)

* Support preview images embedded in safetensors metadata

* Add unit test for safetensors embedded image previews
This commit is contained in:
catboxanon
2024-12-19 17:01:56 -05:00
committed by GitHub
parent 2dda7c11a3
commit 3cacd3fca5
2 changed files with 85 additions and 6 deletions

View File

@@ -1,10 +1,13 @@
from __future__ import annotations
import os
import base64
import json
import time
import logging
import folder_paths
import glob
import comfy.utils
from aiohttp import web
from PIL import Image
from io import BytesIO
@@ -59,13 +62,13 @@ class ModelFileManager:
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
preview_files = self.get_model_previews(full_filename)
default_preview_file = preview_files[0] if len(preview_files) > 0 else None
if default_preview_file is None or not os.path.isfile(default_preview_file):
previews = self.get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
try:
with Image.open(default_preview_file) as img:
with Image.open(default_preview) as img:
img_bytes = BytesIO()
img.save(img_bytes, format="WEBP")
img_bytes.seek(0)
@@ -143,7 +146,7 @@ class ModelFileManager:
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str]:
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
@@ -152,8 +155,10 @@ class ModelFileManager:
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str] = []
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
@@ -161,6 +166,18 @@ class ModelFileManager:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):