mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-03 23:49:57 +08:00

Compare commits


13 Commits

Author SHA1 Message Date
comfyanonymous 619b8cde74 Bump ComfyUI version to 0.3.11 2025-01-16 14:54:48 -05:00
comfyanonymous 31831e6ef1 Code refactor. 2025-01-16 07:23:54 -05:00
comfyanonymous 88ceb28e20 Tweak hunyuan memory usage factor. 2025-01-16 06:31:03 -05:00
comfyanonymous 23289a6a5c Clean up some debug lines. 2025-01-16 04:24:39 -05:00
comfyanonymous 9d8b6c1f46 More accurate memory estimation for cosmos and hunyuan video. 2025-01-16 03:48:40 -05:00
comfyanonymous 6320d05696 Slightly lower hunyuan video memory usage. 2025-01-16 00:23:01 -05:00
comfyanonymous 25683b5b02 Lower cosmos diffusion model memory usage. 2025-01-15 23:46:42 -05:00
comfyanonymous 4758fb64b9 Lower cosmos VAE memory usage by a bit. 2025-01-15 22:57:52 -05:00
comfyanonymous 008761166f Optimize first attention block in cosmos VAE. 2025-01-15 21:48:46 -05:00
comfyanonymous bfd5dfd611 3.13 doesn't work yet. 2025-01-15 20:32:44 -05:00
comfyanonymous 55ade36d01 Remove python 3.8 from test-build workflow. 2025-01-15 20:24:55 -05:00
comfyanonymous 2e20e399ea Add minimum numpy version to requirements.txt 2025-01-15 20:19:56 -05:00
comfyanonymous 3baf92d120 CosmosImageToVideoLatent batch_size now does something. 2025-01-15 17:19:59 -05:00
29 changed files with 130 additions and 1329 deletions

View File

@@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@@ -28,4 +28,4 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements.txt

View File

@@ -1,119 +0,0 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic_db
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic_db/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = sqlite:///user/comfyui.db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -1,3 +0,0 @@
## Generate new revision
1. Update models in `/app/database/models.py`
2. Run `alembic revision --autogenerate -m "{your message}"`
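
For reference, the two steps in this removed README can also be driven from Python through Alembic's command API; `command.upgrade()` is the call `app/database/db.py` (later in this diff) used to apply migrations at startup. A minimal sketch; the config path and revision message are illustrative:

from alembic import command
from alembic.config import Config

config = Config("alembic.ini")   # the alembic.ini shown above, with script_location = alembic_db
# same as: alembic revision --autogenerate -m "your message"
command.revision(config, message="your message", autogenerate=True)
# apply the generated migration(s) up to the latest head
command.upgrade(config, "head")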

View File

@@ -1,75 +0,0 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
from app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -1,28 +0,0 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

@@ -1,58 +0,0 @@
"""init
Revision ID: 2fb22c4fff36
Revises:
Create Date: 2025-03-27 19:00:47.686079
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '2fb22c4fff36'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('model',
sa.Column('type', sa.Text(), nullable=False),
sa.Column('path', sa.Text(), nullable=False),
sa.Column('title', sa.Text(), nullable=True),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('architecture', sa.Text(), nullable=True),
sa.Column('hash', sa.Text(), nullable=True),
sa.Column('source_url', sa.Text(), nullable=True),
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
sa.PrimaryKeyConstraint('type', 'path')
)
op.create_table('tag',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name')
)
op.create_table('model_tag',
sa.Column('model_type', sa.Text(), nullable=False),
sa.Column('model_path', sa.Text(), nullable=False),
sa.Column('tag_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['model_type', 'model_path'], ['model.type', 'model.path'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['tag_id'], ['tag.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('model_type', 'model_path', 'tag_id')
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('model_tag')
op.drop_table('tag')
op.drop_table('model')
# ### end Alembic commands ###

View File

@@ -1,118 +0,0 @@
import logging
import os
import shutil
import sys
from app.database.models import Tag
from comfy.cli_args import args
try:
import alembic
import sqlalchemy
except ImportError as e:
req_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../..", "requirements.txt")
)
logging.error(
f"\n\n********** ERROR ***********\n\nRequirements are not installed ({e}). Please install the requirements.txt file by running:\n{sys.executable} -s -m pip install -r {req_path}\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n"
)
exit(-1)
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
Session = None
def get_alembic_config():
root_path = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
config = Config(config_path)
config.set_main_option("script_location", scripts_path)
config.set_main_option("sqlalchemy.url", args.database_url)
return config
def get_db_path():
url = args.database_url
if url.startswith("sqlite:///"):
return url.split("///")[1]
else:
raise ValueError(f"Unsupported database URL '{url}'.")
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
config = get_alembic_config()
# Check if we need to upgrade
engine = create_engine(db_url)
conn = engine.connect()
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(config)
target_rev = script.get_current_head()
if current_rev != target_rev:
# Backup the database pre upgrade
db_path = get_db_path()
backup_path = db_path + ".bkp"
if os.path.exists(db_path):
shutil.copy(db_path, backup_path)
else:
backup_path = None
try:
command.upgrade(config, target_rev)
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
except Exception as e:
if backup_path:
# Restore the database from backup if upgrade fails
shutil.copy(backup_path, db_path)
os.remove(backup_path)
logging.error(f"Error upgrading database: {e}")
raise e
global Session
Session = sessionmaker(bind=engine)
if not current_rev:
# Init db, populate models
from app.model_processor import model_processor
session = create_session()
model_processor.populate_models(session)
# populate tags
tags = (
"character",
"style",
"concept",
"clothing",
"pose",
"background",
"vehicle",
"object",
"animal",
"action",
)
for tag in tags:
session.add(Tag(name=tag))
session.commit()
def can_create_session():
return Session is not None
def create_session():
return Session()

View File

@@ -1,76 +0,0 @@
from sqlalchemy import (
Column,
Integer,
Text,
DateTime,
Table,
ForeignKeyConstraint,
)
from sqlalchemy.orm import relationship, declarative_base
from sqlalchemy.sql import func
Base = declarative_base()
def to_dict(obj):
fields = obj.__table__.columns.keys()
return {
field: (val.to_dict() if hasattr(val, "to_dict") else val)
for field in fields
if (val := getattr(obj, field))
}
ModelTag = Table(
"model_tag",
Base.metadata,
Column(
"model_type",
Text,
primary_key=True,
),
Column(
"model_path",
Text,
primary_key=True,
),
Column("tag_id", Integer, primary_key=True),
ForeignKeyConstraint(
["model_type", "model_path"], ["model.type", "model.path"], ondelete="CASCADE"
),
ForeignKeyConstraint(["tag_id"], ["tag.id"], ondelete="CASCADE"),
)
class Model(Base):
__tablename__ = "model"
type = Column(Text, primary_key=True)
path = Column(Text, primary_key=True)
title = Column(Text)
description = Column(Text)
architecture = Column(Text)
hash = Column(Text)
source_url = Column(Text)
date_added = Column(DateTime, server_default=func.now())
# Relationship with tags
tags = relationship("Tag", secondary=ModelTag, back_populates="models")
def to_dict(self):
dict = to_dict(self)
dict["tags"] = [tag.to_dict() for tag in self.tags]
return dict
class Tag(Base):
__tablename__ = "tag"
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Text, nullable=False, unique=True)
# Relationship with models
models = relationship("Model", secondary=ModelTag, back_populates="tags")
def to_dict(self):
return to_dict(self)

View File

@@ -1,30 +1,19 @@
from __future__ import annotations
import os
import base64
import json
import time
import logging
from app.database.db import create_session
import folder_paths
import glob
import comfy.utils
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, get_full_path
from app.database.models import Tag, Model
from app.model_processor import get_model_previews, model_processor
from utils.web import dumps
from sqlalchemy.orm import joinedload
import sqlalchemy.exc
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
def bad_request(message: str):
return web.json_response({"error": message}, status=400)
def missing_field(field: str):
return bad_request(f"{field} is required")
def not_found(message: str):
return web.json_response({"error": message + " not found"}, status=404)
class ModelFileManager:
def __init__(self) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
@@ -73,7 +62,7 @@ class ModelFileManager:
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
previews = get_model_previews(full_filename)
previews = self.get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
@@ -87,183 +76,6 @@ class ModelFileManager:
except:
return web.Response(status=404)
@routes.get("/v2/models")
async def get_models(request):
with create_session() as session:
model_path = request.query.get("path", None)
model_type = request.query.get("type", None)
query = session.query(Model).options(joinedload(Model.tags))
if model_path:
query = query.filter(Model.path == model_path)
if model_type:
query = query.filter(Model.type == model_type)
models = query.all()
if model_path and model_type:
if len(models) == 0:
return not_found("Model")
return web.json_response(models[0].to_dict(), dumps=dumps)
return web.json_response([model.to_dict() for model in models], dumps=dumps)
@routes.post("/v2/models")
async def add_model(request):
with create_session() as session:
data = await request.json()
model_type = data.get("type", None)
model_path = data.get("path", None)
if not model_type:
return missing_field("type")
if not model_path:
return missing_field("path")
tags = data.pop("tags", [])
fields = Model.metadata.tables["model"].columns.keys()
# Validate keys are valid model fields
for key in data.keys():
if key not in fields:
return bad_request(f"Invalid field: {key}")
# Validate file exists
if not get_full_path(model_type, model_path):
return not_found(f"File '{model_type}/{model_path}'")
model = Model()
for field in fields:
if field in data:
setattr(model, field, data[field])
model.tags = session.query(Tag).filter(Tag.id.in_(tags)).all()
for tag in tags:
if tag not in [t.id for t in model.tags]:
return not_found(f"Tag '{tag}'")
try:
session.add(model)
session.commit()
except sqlalchemy.exc.IntegrityError as e:
session.rollback()
return bad_request(e.orig.args[0])
model_processor.run()
return web.json_response(model.to_dict(), dumps=dumps)
@routes.delete("/v2/models")
async def delete_model(request):
with create_session() as session:
model_path = request.query.get("path", None)
model_type = request.query.get("type", None)
if not model_path:
return missing_field("path")
if not model_type:
return missing_field("type")
full_path = get_full_path(model_type, model_path)
if full_path:
return bad_request("Model file exists, please delete the file before deleting the model record.")
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
if not model:
return not_found("Model")
session.delete(model)
session.commit()
return web.Response(status=204)
@routes.get("/v2/tags")
async def get_tags(request):
with create_session() as session:
tags = session.query(Tag).all()
return web.json_response(
[{"id": tag.id, "name": tag.name} for tag in tags]
)
@routes.post("/v2/tags")
async def create_tag(request):
with create_session() as session:
data = await request.json()
name = data.get("name", None)
if not name:
return missing_field("name")
tag = Tag(name=name)
session.add(tag)
session.commit()
return web.json_response({"id": tag.id, "name": tag.name})
@routes.delete("/v2/tags")
async def delete_tag(request):
with create_session() as session:
tag_id = request.query.get("id", None)
if not tag_id:
return missing_field("id")
tag = session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
return not_found("Tag")
session.delete(tag)
session.commit()
return web.Response(status=204)
@routes.post("/v2/models/tags")
async def add_model_tag(request):
with create_session() as session:
data = await request.json()
tag_id = data.get("tag", None)
model_path = data.get("path", None)
model_type = data.get("type", None)
if tag_id is None:
return missing_field("tag")
if model_path is None:
return missing_field("path")
if model_type is None:
return missing_field("type")
try:
tag_id = int(tag_id)
except ValueError:
return bad_request("Invalid tag id")
tag = session.query(Tag).filter(Tag.id == tag_id).first()
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
if not model:
return not_found("Model")
model.tags.append(tag)
session.commit()
return web.json_response(model.to_dict(), dumps=dumps)
@routes.delete("/v2/models/tags")
async def delete_model_tag(request):
with create_session() as session:
tag_id = request.query.get("tag", None)
model_path = request.query.get("path", None)
model_type = request.query.get("type", None)
if tag_id is None:
return missing_field("tag")
if model_path is None:
return missing_field("path")
if model_type is None:
return missing_field("type")
try:
tag_id = int(tag_id)
except ValueError:
return bad_request("Invalid tag id")
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
if not model:
return not_found("Model")
model.tags = [tag for tag in model.tags if tag.id != tag_id]
session.commit()
return web.Response(status=204)
@routes.get("/v2/models/missing")
async def get_missing_models(request):
return web.json_response(model_processor.missing_models)
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
@@ -334,5 +146,39 @@ class ModelFileManager:
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):
self.clear_cache()

View File

@@ -1,263 +0,0 @@
import base64
from datetime import datetime
import glob
import hashlib
from io import BytesIO
import json
import logging
import os
import threading
import time
import comfy.utils
from app.database.models import Model
from app.database.db import create_session
from comfy.cli_args import args
from folder_paths import (
filter_files_content_types,
get_full_path,
folder_names_and_paths,
get_filename_list,
)
from PIL import Image
from urllib import request
def get_model_previews(
filepath: str, check_metadata: bool = True
) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if not check_metadata:
return result
safetensors_file = next(
filter(lambda x: x.endswith(".safetensors"), match_files), None
)
safetensors_metadata = {}
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(
safetensors_filepath, max_size=8 * 1024 * 1024
)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get(
"ssmd_cover_images", None
)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
class ModelProcessor:
def __init__(self):
self._thread = None
self._lock = threading.Lock()
self._run = False
self.missing_models = []
def run(self):
if args.disable_model_processing:
return
if self._thread is None:
# Lock to prevent multiple threads from starting
with self._lock:
self._run = True
if self._thread is None:
self._thread = threading.Thread(target=self._process_models)
self._thread.daemon = True
self._thread.start()
def populate_models(self, session):
# Ensure database state matches filesystem
existing_models = session.query(Model).all()
for folder_name in folder_names_and_paths.keys():
if folder_name == "custom_nodes" or folder_name == "configs":
continue
seen = set()
files = get_filename_list(folder_name)
for file in files:
if file in seen:
logging.warning(f"Skipping duplicate named model: {file}")
continue
seen.add(file)
existing_model = None
for model in existing_models:
if model.path == file and model.type == folder_name:
existing_model = model
break
if existing_model:
# Model already exists in db, remove from list and skip
existing_models.remove(existing_model)
continue
file_path = get_full_path(folder_name, file)
model = Model(
path=file,
type=folder_name,
date_added=datetime.fromtimestamp(os.path.getctime(file_path)),
)
session.add(model)
for model in existing_models:
if not get_full_path(model.type, model.path):
logging.warning(f"Model {model.path} not found")
self.missing_models.append({"type": model.type, "path": model.path})
session.commit()
def _get_models(self, session):
models = session.query(Model).filter(Model.hash == None).all()
return models
def _process_file(self, model_path):
is_safetensors = model_path.endswith(".safetensors")
metadata = {}
h = hashlib.sha256()
with open(model_path, "rb", buffering=0) as f:
if is_safetensors:
# Read header length (8 bytes)
header_size_bytes = f.read(8)
header_len = int.from_bytes(header_size_bytes, "little")
h.update(header_size_bytes)
# Read header
header_bytes = f.read(header_len)
h.update(header_bytes)
try:
metadata = json.loads(header_bytes)
except json.JSONDecodeError:
pass
# Read rest of file
b = bytearray(128 * 1024)
mv = memoryview(b)
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest(), metadata
def _populate_info(self, model, metadata):
model.title = metadata.get("modelspec.title", None)
model.description = metadata.get("modelspec.description", None)
model.architecture = metadata.get("modelspec.architecture", None)
def _extract_image(self, model_path, metadata):
# check if image already exists
if len(get_model_previews(model_path, check_metadata=False)) > 0:
return
image_path = os.path.splitext(model_path)[0] + ".webp"
if os.path.exists(image_path):
return
cover_images = metadata.get("ssmd_cover_images", None)
image = None
if cover_images:
try:
cover_images = json.loads(cover_images)
if len(cover_images) > 0:
image_data = cover_images[0]
image = Image.open(BytesIO(base64.b64decode(image_data)))
except Exception as e:
logging.warning(
f"Error extracting cover image for model {model_path}: {e}"
)
if not image:
thumbnail = metadata.get("modelspec.thumbnail", None)
if thumbnail:
try:
response = request.urlopen(thumbnail)
image = Image.open(response)
except Exception as e:
logging.warning(
f"Error extracting thumbnail for model {model_path}: {e}"
)
if image:
image.thumbnail((512, 512))
image.save(image_path)
image.close()
def _process_models(self):
with create_session() as session:
checked = set()
self.populate_models(session)
while self._run:
self._run = False
models = self._get_models(session)
if len(models) == 0:
break
for model in models:
# prevent looping on the same model if it crashes
if model.path in checked:
continue
checked.add(model.path)
try:
time.sleep(0)
now = time.time()
model_path = get_full_path(model.type, model.path)
if not model_path:
logging.warning(f"Model {model.path} not found")
self.missing_models.append(model.path)
continue
logging.debug(f"Processing model {model_path}")
hash, header = self._process_file(model_path)
logging.debug(
f"Processed model {model_path} in {time.time() - now} seconds"
)
model.hash = hash
if header:
metadata = header.get("__metadata__", None)
if metadata:
self._populate_info(model, metadata)
self._extract_image(model_path, metadata)
session.commit()
except Exception as e:
logging.error(f"Error processing model {model.path}: {e}")
with self._lock:
self._thread = None
model_processor = ModelProcessor()

View File

@@ -178,12 +178,6 @@ parser.add_argument(
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
if comfy.options.args_parsing:
args = parser.parse_args()
else:

View File

@@ -168,14 +168,18 @@ class Attention(nn.Module):
k = self.to_k[1](k)
v = self.to_v[1](v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
q = apply_rotary_pos_emb(q, rope_emb)
k = apply_rotary_pos_emb(k, rope_emb)
return q, k, v
# apply_rotary_pos_emb inlined
q_shape = q.shape
q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
def cal_attn(self, q, k, v, mask=None):
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
# apply_rotary_pos_emb inlined
k_shape = k.shape
k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
return q, k, v
def forward(
self,
@@ -191,7 +195,10 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
return self.cal_attn(q, k, v, mask)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
class FeedForward(nn.Module):
@@ -788,10 +795,7 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x = x + extra_per_block_pos_emb
for block in self.blocks:
x = block(
x,

View File

@@ -30,6 +30,8 @@ import torch.nn as nn
import torch.nn.functional as F
import logging
from comfy.ldm.modules.diffusionmodules.model import vae_attention
from .patching import (
Patcher,
Patcher3D,
@@ -400,6 +402,8 @@ class CausalAttnBlock(nn.Module):
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.optimized_attention = vae_attention()
def forward(self, x: torch.Tensor) -> torch.Tensor:
h_ = x
h_ = self.norm(h_)
@@ -413,18 +417,7 @@ class CausalAttnBlock(nn.Module):
v, batch_size = time2batch(v)
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h * w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.optimized_attention(q, k, v)
h_ = batch2time(h_, batch_size)
h_ = self.proj_out(h_)
@@ -871,18 +864,16 @@ class EncoderFactorized(nn.Module):
x = self.patcher3d(x)
# downsampling
hs = [self.conv_in(x)]
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
h = self.down[i_level].downsample(h)
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)

View File

@@ -281,54 +281,76 @@ class UnPatcher3D(UnPatcher):
hh = hh.to(dtype=dtype)
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
del x
# Height height transposed convolutions.
xll = F.conv_transpose3d(
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlll
xll += F.conv_transpose3d(
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xllh
xlh = F.conv_transpose3d(
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlhl
xlh += F.conv_transpose3d(
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlhh
xhl = F.conv_transpose3d(
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhll
xhl += F.conv_transpose3d(
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhlh
xhh = F.conv_transpose3d(
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhhl
xhh += F.conv_transpose3d(
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhhh
# Handles width transposed convolutions.
xl = F.conv_transpose3d(
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xll
xl += F.conv_transpose3d(
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xlh
xh = F.conv_transpose3d(
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xhl
xh += F.conv_transpose3d(
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xhh
# Handles time axis transposed convolutions.
x = F.conv_transpose3d(
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
del xl
x += F.conv_transpose3d(
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
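
The `del` statements threaded through this hunk free each wavelet band as soon as its transposed convolution has been accumulated; this is how the "Lower cosmos VAE memory usage by a bit" commit trims peak VRAM in the unpatcher. A small standalone sketch of the same del-after-use idea; the shapes and filters below are made up for illustration:

import torch
import torch.nn.functional as F

# Hypothetical tensors and filters, only to illustrate the del-after-use pattern above.
g = 16
xll = torch.randn(1, g, 4, 8, 8)   # low band (stand-in for one wavelet component)
xlh = torch.randn(1, g, 4, 8, 8)   # high band
hl = torch.randn(g, 1, 2)          # per-group 1D synthesis filter (assumed shape)
hh = torch.randn(g, 1, 2)

xl = F.conv_transpose3d(xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
del xll   # only reference to the low band, so its memory can be released before the next op
xl += F.conv_transpose3d(xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
del xlh   # likewise for the high band; peak usage stays close to a single intermediate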

View File

@@ -168,7 +168,7 @@ class GeneralDIT(nn.Module):
operations=operations,
)
self.build_pos_embed(device=device)
self.build_pos_embed(device=device, dtype=dtype)
self.block_x_format = block_x_format
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
@@ -210,7 +210,7 @@ class GeneralDIT(nn.Module):
operations=operations,
)
def build_pos_embed(self, device=None):
def build_pos_embed(self, device=None, dtype=None):
if self.pos_emb_cls == "rope3d":
cls_type = VideoRopePosition3DEmb
else:
@@ -242,6 +242,7 @@ class GeneralDIT(nn.Module):
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
kwargs["device"] = device
kwargs["dtype"] = dtype
self.extra_pos_embedder = LearnablePosEmbAxis(
**kwargs,
)
@@ -476,6 +477,8 @@ class GeneralDIT(nn.Module):
inputs["original_shape"],
)
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
del inputs
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
assert (
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
@@ -486,6 +489,8 @@ class GeneralDIT(nn.Module):
self.blocks["block0"].x_format == block.x_format
), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
x = block(
x,
affline_emb_B_D,
@@ -493,7 +498,6 @@ class GeneralDIT(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")

View File

@@ -173,6 +173,7 @@ class LearnablePosEmbAxis(VideoPositionEmb):
len_w: int,
len_t: int,
device=None,
dtype=None,
**kwargs,
):
"""
@@ -184,9 +185,9 @@ class LearnablePosEmbAxis(VideoPositionEmb):
self.interpolation = interpolation
assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:

View File

@@ -89,8 +89,8 @@ class CausalContinuousVideoTokenizer(nn.Module):
self.distribution = IdentityDistribution() # ContinuousFormulation[formulation_name].value()
num_parameters = sum(param.numel() for param in self.parameters())
logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
logging.info(
logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
logging.debug(
f"z_channels={z_channels}, latent_channels={self.latent_channels}."
)

View File

@@ -230,8 +230,7 @@ class SingleStreamBlock(nn.Module):
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)

View File

@@ -5,8 +5,15 @@ from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q, k = apply_rope(q, k, pe)
q_shape = q.shape
k_shape = k.shape
q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)

View File

@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
return out
def vae_attention():
if model_management.xformers_enabled_vae():
logging.info("Using xformers attention in VAE")
return xformers_attention
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention in VAE")
return pytorch_attention
else:
logging.info("Using split attention in VAE")
return normal_attention
class AttnBlock(nn.Module):
def __init__(self, in_channels, conv_op=ops.Conv2d):
super().__init__()
@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
stride=1,
padding=0)
if model_management.xformers_enabled_vae():
logging.info("Using xformers attention in VAE")
self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention
else:
logging.info("Using split attention in VAE")
self.optimized_attention = normal_attention
self.optimized_attention = vae_attention()
def forward(self, x):
h_ = x

View File

@@ -388,8 +388,8 @@ class VAE:
ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
#TODO: these values are a bit off because this is not a standard VAE
self.memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (500 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
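
The new lambdas replace the old constants (220 for decode, 500 for encode) with smaller factors and, for encode, a frame count rounded up to a multiple of 8. A quick sketch of the decode arithmetic with a made-up latent shape; the formula is the one from the diff, while the shape and the dtype_size stand-in are assumptions:

import torch

shape = (1, 16, 3, 60, 104)   # hypothetical 5D video latent: (batch, channels, t, h, w)
dtype = torch.bfloat16
dtype_size = torch.tensor([], dtype=dtype).element_size()   # stand-in for model_management.dtype_size(dtype)

memory_used_decode = 50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8) * dtype_size
print(memory_used_decode / 1024**3)   # ~0.89 GiB budgeted for decoding this latent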

View File

@@ -788,7 +788,7 @@ class HunyuanVideo(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 2.0 #TODO
memory_usage_factor = 1.8 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@@ -839,7 +839,7 @@ class CosmosT2V(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Cosmos1CV8x8x8
memory_usage_factor = 2.4 #TODO
memory_usage_factor = 1.6 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO

View File

@@ -71,8 +71,8 @@ class CosmosImageToVideoLatent:
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
out_latent = {}
out_latent["samples"] = latent
out_latent["noise_mask"] = mask
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
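
The repeat above is what makes `batch_size` take effect: the single prepared latent and its noise mask are tiled along the batch dimension, leaving every other dimension untouched. A tiny sketch of the repeat-tuple pattern with a hypothetical latent shape:

import torch

latent = torch.zeros(1, 16, 5, 60, 104)   # one prepared video latent (hypothetical shape)
batch_size = 4
batched = latent.repeat((batch_size,) + (1,) * (latent.ndim - 1))
print(batched.shape)   # torch.Size([4, 16, 5, 60, 104])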

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.10"
__version__ = "0.3.11"

main.py
View File

@@ -138,8 +138,6 @@ import server
from server import BinaryEventTypes
import nodes
import comfy.model_management
from app.database.db import can_create_session, init_db
from app.model_processor import model_processor
def cuda_malloc_warning():
device = comfy.model_management.get_torch_device()
@@ -264,11 +262,6 @@ def start_comfyui(asyncio_loop=None):
cuda_malloc_warning()
try:
init_db()
except Exception as e:
logging.error(f"Failed to initialize database. Please report this error as in future the database will be required: {e}")
prompt_server.add_routes()
hijack_progress(prompt_server)
@@ -276,10 +269,6 @@ def start_comfyui(asyncio_loop=None):
if args.quick_test_for_ci:
exit(0)
# Scan for changed model files and update db
if can_create_session():
model_processor.run()
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
call_on_start = None

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.10"
version = "0.3.11"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -2,6 +2,7 @@ torch
torchsde
torchvision
torchaudio
numpy>=1.25.0
einops
transformers>=4.28.1
tokenizers>=0.13.3
@@ -13,8 +14,6 @@ Pillow
scipy
tqdm
psutil
alembic
SQLAlchemy
#non essential dependencies:
kornia>=0.7.1

View File

@@ -7,33 +7,11 @@ from PIL import Image
from aiohttp import web
from unittest.mock import patch
from app.model_manager import ModelFileManager
from app.database.models import Base, Model, Tag
from comfy.cli_args import args
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def session():
# Configure in-memory database
args.database_url = "sqlite:///:memory:"
# Create engine and session factory
engine = create_engine(args.database_url)
Session = sessionmaker(bind=engine)
# Create all tables
Base.metadata.create_all(engine)
# Patch Session factory
with patch('app.database.db.Session', Session):
yield Session()
Base.metadata.drop_all(engine)
@pytest.fixture
def model_manager():
return ModelFileManager()
@@ -82,287 +60,3 @@ async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
# Clean up
img.close()
async def test_get_models(aiohttp_client, app, session):
tag = Tag(name='test_tag')
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
model.tags.append(tag)
session.add(tag)
session.add(model)
session.commit()
client = await aiohttp_client(app)
resp = await client.get('/v2/models')
assert resp.status == 200
data = await resp.json()
assert len(data) == 1
assert data[0]['path'] == 'model1.safetensors'
assert len(data[0]['tags']) == 1
assert data[0]['tags'][0]['name'] == 'test_tag'
async def test_add_model(aiohttp_client, app, session):
tag = Tag(name='test_tag')
session.add(tag)
session.commit()
tag_id = tag.id
with patch('app.model_manager.model_processor') as mock_processor:
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'title': 'Test Model',
'tags': [tag_id]
})
assert resp.status == 200
data = await resp.json()
assert data['path'] == 'model1.safetensors'
assert len(data['tags']) == 1
assert data['tags'][0]['name'] == 'test_tag'
# Ensure that models are re-processed after adding
mock_processor.run.assert_called_once()
async def test_delete_model(aiohttp_client, app, session):
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
with patch('app.model_manager.get_full_path', return_value=None):
client = await aiohttp_client(app)
resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
assert resp.status == 204
# Verify model was deleted
model = session.query(Model).first()
assert model is None
async def test_delete_model_file_exists(aiohttp_client, app, session):
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
assert resp.status == 400
data = await resp.json()
assert "file exists" in data["error"].lower()
# Verify model was not deleted
model = session.query(Model).first()
assert model is not None
assert model.path == 'model1.safetensors'
async def test_get_tags(aiohttp_client, app, session):
tags = [Tag(name='tag1'), Tag(name='tag2')]
for tag in tags:
session.add(tag)
session.commit()
client = await aiohttp_client(app)
resp = await client.get('/v2/tags')
assert resp.status == 200
data = await resp.json()
assert len(data) == 2
assert {t['name'] for t in data} == {'tag1', 'tag2'}
async def test_create_tag(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/tags', json={'name': 'new_tag'})
assert resp.status == 200
data = await resp.json()
assert data['name'] == 'new_tag'
# Verify tag was created
tag = session.query(Tag).first()
assert tag.name == 'new_tag'
async def test_delete_tag(aiohttp_client, app, session):
tag = Tag(name='test_tag')
session.add(tag)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.delete(f'/v2/tags?id={tag_id}')
assert resp.status == 204
# Verify tag was deleted
tag = session.query(Tag).first()
assert tag is None
async def test_add_model_tag(aiohttp_client, app, session):
tag = Tag(name='test_tag')
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(tag)
session.add(model)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.post('/v2/models/tags', json={
'tag': tag_id,
'type': 'checkpoints',
'path': 'model1.safetensors'
})
assert resp.status == 200
data = await resp.json()
assert len(data['tags']) == 1
assert data['tags'][0]['name'] == 'test_tag'
async def test_delete_model_tag(aiohttp_client, app, session):
tag = Tag(name='test_tag')
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
model.tags.append(tag)
session.add(tag)
session.add(model)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.delete(f'/v2/models/tags?tag={tag_id}&type=checkpoints&path=model1.safetensors')
assert resp.status == 204
# Verify tag was removed
model = session.query(Model).first()
assert len(model.tags) == 0
async def test_add_model_duplicate(aiohttp_client, app, session):
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'title': 'Duplicate Model'
})
assert resp.status == 400
async def test_add_model_missing_fields(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={})
assert resp.status == 400
async def test_add_tag_missing_name(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/tags', json={})
assert resp.status == 400
async def test_delete_model_not_found(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.delete('/v2/models?type=checkpoints&path=nonexistent.safetensors')
assert resp.status == 404
async def test_delete_tag_not_found(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.delete('/v2/tags?id=999')
assert resp.status == 404
async def test_add_model_missing_path(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'title': 'Test Model'
})
assert resp.status == 400
data = await resp.json()
assert "path" in data["error"].lower()
async def test_add_model_invalid_field(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'invalid_field': 'some value'
})
assert resp.status == 400
data = await resp.json()
assert "invalid field" in data["error"].lower()
async def test_add_model_nonexistent_file(aiohttp_client, app, session):
with patch('app.model_manager.get_full_path', return_value=None):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'nonexistent.safetensors'
})
assert resp.status == 404
data = await resp.json()
assert "file" in data["error"].lower()
async def test_add_model_invalid_tag(aiohttp_client, app, session):
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'tags': [999] # Non-existent tag ID
})
assert resp.status == 404
data = await resp.json()
assert "tag" in data["error"].lower()
async def test_add_tag_to_nonexistent_model(aiohttp_client, app, session):
# Create a tag but no model
tag = Tag(name='test_tag')
session.add(tag)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.post('/v2/models/tags', json={
'tag': tag_id,
'type': 'checkpoints',
'path': 'nonexistent.safetensors'
})
assert resp.status == 404
data = await resp.json()
assert "model" in data["error"].lower()
async def test_delete_model_tag_invalid_tag_id(aiohttp_client, app, session):
# Create a model first
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
client = await aiohttp_client(app)
resp = await client.delete('/v2/models/tags?tag=not_a_number&type=checkpoint&path=model1.safetensors')
assert resp.status == 400
data = await resp.json()
assert "invalid tag id" in data["error"].lower()

View File

@@ -1,12 +0,0 @@
import json
from datetime import datetime
class DateTimeEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
dumps = DateTimeEncoder().encode
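
This removed helper was the `dumps=` argument handed to `web.json_response` in the model manager routes above, so that `date_added` and other datetime columns serialize as ISO 8601 strings. A minimal usage sketch:

import json
from datetime import datetime

class DateTimeEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)

dumps = DateTimeEncoder().encode
print(dumps({"date_added": datetime(2025, 1, 16, 14, 54, 48)}))
# -> {"date_added": "2025-01-16T14:54:48"}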