mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-03 23:49:57 +08:00

Compare commits


2 Commits

Author SHA1 Message Date
pythongosssss
fde9fdddff Allow running with non working 2025-03-28 11:46:05 +08:00
pythongosssss
7bf381bc9e Add model management and database
- use sqlalchemy + alembic + sqlite for db
- extract model data and previews
- endpoints for db interactions
- add tests
2025-03-28 11:39:56 +08:00
23 changed files with 1257 additions and 1261 deletions

alembic.ini Normal file

@@ -0,0 +1,119 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic_db
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires python>=3.9 or the backports.zoneinfo library, and the tzdata library.
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic_db/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = sqlite:///user/comfyui.db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

alembic_db/README.md Normal file

@@ -0,0 +1,3 @@
## Generate new revision
1. Update models in `/app/database/models.py`
2. Run `alembic revision --autogenerate -m "{your message}"`
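
The same revision can also be generated from Python via Alembic's command API. A minimal sketch, assuming the layout above and reusing `get_alembic_config()` from `app/database/db.py`:

# Sketch (not part of this commit): programmatic equivalent of step 2.
from alembic import command

from app.database.db import get_alembic_config

config = get_alembic_config()
# Compares the metadata in app/database/models.py against the database and
# writes a new script under alembic_db/versions/.
command.revision(config, message="your message", autogenerate=True)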

alembic_db/env.py Normal file

@@ -0,0 +1,75 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
from app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
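
For reference, `run_migrations_offline()` corresponds to Alembic's `--sql` mode, which emits the migration SQL without needing a DBAPI connection. A hedged sketch of driving both modes programmatically, again reusing `get_alembic_config()` from `app/database/db.py`:

# Sketch (not part of this commit): the two modes env.py handles.
from alembic import command

from app.database.db import get_alembic_config

config = get_alembic_config()
command.upgrade(config, "head", sql=True)  # offline: prints the upgrade SQL
command.upgrade(config, "head")            # online: applies migrations directly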

alembic_db/script.py.mako Normal file

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

alembic_db/versions/2fb22c4fff36_init.py Normal file

@@ -0,0 +1,58 @@
"""init
Revision ID: 2fb22c4fff36
Revises:
Create Date: 2025-03-27 19:00:47.686079
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '2fb22c4fff36'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('model',
sa.Column('type', sa.Text(), nullable=False),
sa.Column('path', sa.Text(), nullable=False),
sa.Column('title', sa.Text(), nullable=True),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('architecture', sa.Text(), nullable=True),
sa.Column('hash', sa.Text(), nullable=True),
sa.Column('source_url', sa.Text(), nullable=True),
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
sa.PrimaryKeyConstraint('type', 'path')
)
op.create_table('tag',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name')
)
op.create_table('model_tag',
sa.Column('model_type', sa.Text(), nullable=False),
sa.Column('model_path', sa.Text(), nullable=False),
sa.Column('tag_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['model_type', 'model_path'], ['model.type', 'model.path'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['tag_id'], ['tag.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('model_type', 'model_path', 'tag_id')
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('model_tag')
op.drop_table('tag')
op.drop_table('model')
# ### end Alembic commands ###

app/database/db.py

@@ -1,126 +1,118 @@
Removed (the previous sqlite3 connection-pool implementation):

import logging
import os
import sqlite3
from contextlib import contextmanager
from queue import Queue, Empty, Full
import threading
from app.database.updater import DatabaseUpdater
import folder_paths
from comfy.cli_args import args


class Database:
    def __init__(self, database_path=None, pool_size=1):
        if database_path is None:
            self.exists = False
            database_path = "file::memory:?cache=shared"
        else:
            self.exists = os.path.exists(database_path)
        self.database_path = database_path
        self.pool_size = pool_size
        # Store connections in a pool, default to 1 as normal usage is going to be from a single thread at a time
        self.connection_pool: Queue = Queue(maxsize=pool_size)
        self._db_lock = threading.Lock()
        self._initialized = False
        self._closing = False
        self._after_update_callbacks = []

    def _setup(self):
        if self._initialized:
            return
        with self._db_lock:
            if not self._initialized:
                self._make_db()
                self._initialized = True

    def _create_connection(self):
        # TODO: Catch error for sqlite lib missing on linux
        logging.info(f"Creating connection to {self.database_path}")
        conn = sqlite3.connect(
            self.database_path,
            check_same_thread=False,
            uri=self.database_path.startswith("file::"),
        )
        conn.execute("PRAGMA foreign_keys = ON")
        self.exists = True
        logging.info(f"Connected!")
        return conn

    def _make_db(self):
        with self._get_connection() as con:
            updater = DatabaseUpdater(con, self.database_path)
            result = updater.update()
            if result is not None:
                old_version, new_version = result
                for callback in self._after_update_callbacks:
                    callback(old_version, new_version)

    def _transform(self, row, columns):
        return {col.name: value for value, col in zip(row, columns)}

    @contextmanager
    def _get_connection(self):
        if self._closing:
            raise Exception("Database is shutting down")
        try:
            # Try to get connection from pool
            connection = self.connection_pool.get_nowait()
        except Empty:
            # Create new connection if pool is empty
            connection = self._create_connection()
        try:
            yield connection
        finally:
            try:
                # Try to add to pool if it's empty
                self.connection_pool.put_nowait(connection)
            except Full:
                # Pool is full, close the connection
                connection.close()

    @contextmanager
    def get_connection(self):
        # Setup the database if it's not already initialized
        self._setup()
        with self._get_connection() as connection:
            yield connection

    def execute(self, sql, *args):
        with self.get_connection() as connection:
            cursor = connection.execute(sql, args)
            results = cursor.fetchall()
            return results

    def register_after_update_callback(self, callback):
        self._after_update_callbacks.append(callback)

    def close(self):
        if self._closing:
            return
        # Drain and close all connections in the pool
        self._closing = True
        while True:
            try:
                conn = self.connection_pool.get_nowait()
                conn.close()
            except Empty:
                break
        self._closing = False

    def __del__(self):
        try:
            self.close()
        except:
            pass


# Create a global instance
db_path = None
if not args.memory_database:
    db_path = folder_paths.get_user_directory() + "/comfyui.db"
db = Database(db_path)

Added (the SQLAlchemy + Alembic implementation):

import logging
import os
import shutil
import sys

from app.database.models import Tag
from comfy.cli_args import args

try:
    import alembic
    import sqlalchemy
except ImportError as e:
    req_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../..", "requirements.txt")
    )
    logging.error(
        f"\n\n********** ERROR ***********\n\nRequirements are not installed ({e}). Please install the requirements.txt file by running:\n{sys.executable} -s -m pip install -r {req_path}\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n"
    )
    exit(-1)

from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

Session = None


def get_alembic_config():
    root_path = os.path.join(os.path.dirname(__file__), "../..")
    config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
    scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))

    config = Config(config_path)
    config.set_main_option("script_location", scripts_path)
    config.set_main_option("sqlalchemy.url", args.database_url)

    return config


def get_db_path():
    url = args.database_url
    if url.startswith("sqlite:///"):
        return url.split("///")[1]
    else:
        raise ValueError(f"Unsupported database URL '{url}'.")


def init_db():
    db_url = args.database_url
    logging.debug(f"Database URL: {db_url}")
    config = get_alembic_config()

    # Check if we need to upgrade
    engine = create_engine(db_url)
    conn = engine.connect()

    context = MigrationContext.configure(conn)
    current_rev = context.get_current_revision()

    script = ScriptDirectory.from_config(config)
    target_rev = script.get_current_head()

    if current_rev != target_rev:
        # Backup the database pre upgrade
        db_path = get_db_path()
        backup_path = db_path + ".bkp"
        if os.path.exists(db_path):
            shutil.copy(db_path, backup_path)
        else:
            backup_path = None

        try:
            command.upgrade(config, target_rev)
            logging.info(f"Database upgraded from {current_rev} to {target_rev}")
        except Exception as e:
            if backup_path:
                # Restore the database from backup if upgrade fails
                shutil.copy(backup_path, db_path)
                os.remove(backup_path)
            logging.error(f"Error upgrading database: {e}")
            raise e

    global Session
    Session = sessionmaker(bind=engine)

    if not current_rev:
        # Init db, populate models
        from app.model_processor import model_processor

        session = create_session()
        model_processor.populate_models(session)

        # populate tags
        tags = (
            "character",
            "style",
            "concept",
            "clothing",
            "pose",
            "background",
            "vehicle",
            "object",
            "animal",
            "action",
        )
        for tag in tags:
            session.add(Tag(name=tag))
        session.commit()


def can_create_session():
    return Session is not None


def create_session():
    return Session()
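
Usage-wise, consumers call `init_db()` once and then obtain sessions from the module-level factory; a small illustrative sketch (the query itself is hypothetical):

# Sketch (not part of this commit): typical consumer of the new db module.
from app.database.db import can_create_session, create_session, init_db
from app.database.models import Model

init_db()
# can_create_session() guards against init_db() having failed,
# in which case Session stays None.
if can_create_session():
    with create_session() as session:
        for model in session.query(Model).filter(Model.type == "checkpoints"):
            print(model.path, model.hash)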

app/database/entities.py

@@ -1,343 +0,0 @@
from typing import Optional, Any, Callable
from dataclasses import dataclass
from functools import wraps
from aiohttp import web
from app.database.db import db
primitives = (bool, str, int, float, type(None))
def is_primitive(obj):
return isinstance(obj, primitives)
class EntityError(Exception):
def __init__(
self, message: str, field: str = None, value: Any = None, status_code: int = 400
):
self.message = message
self.field = field
self.value = value
self.status_code = status_code
super().__init__(self.message)
def to_json(self):
result = {"message": self.message}
if self.field is not None:
result["field"] = self.field
if self.value is not None:
result["value"] = self.value
return result
def __str__(self) -> str:
return f"{self.message} {self.field} {self.value}"
class EntityCommon(dict):
@classmethod
def _get_route(cls, include_key: bool):
route = f"/db/{cls._table_name}"
if include_key:
route += "".join([f"/{{{k}}}" for k in cls._key_columns])
return route
@classmethod
def _register_route(cls, routes, verb: str, include_key: bool, handler: Callable):
route = cls._get_route(include_key)
@getattr(routes, verb)(route)
async def _(request):
try:
data = await handler(request)
if data is None:
return web.json_response(status=204)
return web.json_response(data)
except EntityError as e:
return web.json_response(e.to_json(), status=e.status_code)
@classmethod
def _transform(cls, row: list[Any]):
return {col: value for col, value in zip(cls._columns, row)}
@classmethod
def _transform_rows(cls, rows: list[list[Any]]):
return [cls._transform(row) for row in rows]
@classmethod
def _extract_key(cls, request):
return {key: request.match_info.get(key, None) for key in cls._key_columns}
@classmethod
def _validate(cls, fields: list[str], data: dict, allow_missing: bool = False):
result = {}
if not isinstance(data, dict):
raise EntityError("Invalid data")
# Ensure all required fields are present
for field in data:
if field not in fields:
raise EntityError("Unknown field", field)
for key in fields:
col = cls._columns[key]
if key not in data:
if col.required and not allow_missing:
raise EntityError("Missing field", key)
else:
# e.g. for updates, we allow missing fields
continue
elif data[key] is None and col.required:
# Dont allow None for required fields
raise EntityError("Required field", key)
# Validate data type
value = data[key]
if value is not None and not is_primitive(value):
raise EntityError("Invalid value", key, value)
try:
type = col.type
if value is not None and not isinstance(value, type):
value = type(value)
result[key] = value
except Exception:
raise EntityError("Invalid value", key, value)
return result
@classmethod
def _validate_id(cls, id: dict):
return cls._validate(cls._key_columns, id)
@classmethod
def _validate_data(cls, data: dict, allow_missing: bool = False):
return cls._validate(cls._columns.keys(), data, allow_missing)
def __setattr__(self, name, value):
if name in self._columns:
self[name] = value
super().__setattr__(name, value)
def __getattr__(self, name):
if name in self:
return self[name]
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
class GetEntity(EntityCommon):
@classmethod
def get(cls, top: Optional[int] = None, where: Optional[str] = None):
limit = ""
if top is not None and isinstance(top, int):
limit = f" LIMIT {top}"
result = db.execute(
f"SELECT * FROM {cls._table_name}{limit}{f' WHERE {where}' if where else ''}",
)
# Map each row in result to an instance of the class
return cls._transform_rows(result)
@classmethod
def register_route(cls, routes):
async def get_handler(request):
top = request.rel_url.query.get("top", None)
if top is not None:
try:
top = int(top)
except Exception:
raise EntityError("Invalid top parameter", "top", top)
return cls.get(top)
cls._register_route(routes, "get", False, get_handler)
class GetEntityById(EntityCommon):
@classmethod
def get_by_id(cls, id: dict):
id = cls._validate_id(id)
result = db.execute(
f"SELECT * FROM {cls._table_name} WHERE {cls._where_clause}",
*[id[key] for key in cls._key_columns],
)
return cls._transform_rows(result)
@classmethod
def register_route(cls, routes):
async def get_by_id_handler(request):
id = cls._extract_key(request)
return cls.get_by_id(id)
cls._register_route(routes, "get", True, get_by_id_handler)
class CreateEntity(EntityCommon):
@classmethod
def create(cls, data: dict, allow_upsert: bool = False):
data = cls._validate_data(data)
values = ", ".join(["?"] * len(data))
on_conflict = ""
data_keys = ", ".join(list(data.keys()))
if allow_upsert:
# Remove key columns from data
upsert_keys = [key for key in data if key not in cls._key_columns]
set_clause = ", ".join([f"{k} = excluded.{k}" for k in upsert_keys])
on_conflict = f" ON CONFLICT ({', '.join(cls._key_columns)}) DO UPDATE SET {set_clause}"
sql = f"INSERT INTO {cls._table_name} ({data_keys}) VALUES ({values}){on_conflict} RETURNING *"
result = db.execute(
sql,
*[data[key] for key in data],
)
if len(result) == 0:
raise EntityError("Failed to create entity", status_code=500)
return cls._transform_rows(result)[0]
@classmethod
def register_route(cls, routes):
async def create_handler(request):
data = await request.json()
return cls.create(data)
cls._register_route(routes, "post", False, create_handler)
class UpdateEntity(EntityCommon):
@classmethod
def update(cls, id: list, data: dict):
id = cls._validate_id(id)
data = cls._validate_data(data, allow_missing=True)
sql = f"UPDATE {cls._table_name} SET {', '.join([f'{k} = ?' for k in data])} WHERE {cls._where_clause} RETURNING *"
result = db.execute(
sql,
*[data[key] for key in data],
*[id[key] for key in cls._key_columns],
)
if len(result) == 0:
raise EntityError("Failed to update entity", status_code=404)
return cls._transform_rows(result)[0]
@classmethod
def register_route(cls, routes):
async def update_handler(request):
id = cls._extract_key(request)
data = await request.json()
return cls.update(id, data)
cls._register_route(routes, "patch", True, update_handler)
class UpsertEntity(CreateEntity):
@classmethod
def upsert(cls, data: dict):
return cls.create(data, allow_upsert=True)
@classmethod
def register_route(cls, routes):
async def upsert_handler(request):
data = await request.json()
return cls.upsert(data)
cls._register_route(routes, "put", False, upsert_handler)
class DeleteEntity(EntityCommon):
@classmethod
def delete(cls, id: list):
id = cls._validate_id(id)
db.execute(
f"DELETE FROM {cls._table_name} WHERE {cls._where_clause}",
*[id[key] for key in cls._key_columns],
)
@classmethod
def register_route(cls, routes):
async def delete_handler(request):
id = cls._extract_key(request)
cls.delete(id)
cls._register_route(routes, "delete", True, delete_handler)
class BaseEntity(GetEntity, CreateEntity, UpdateEntity, DeleteEntity, GetEntityById):
pass
@dataclass
class Column:
type: Any
required: bool = False
key: bool = False
default: Any = None
def column(type_: Any, required: bool = False, key: bool = False, default: Any = None):
return Column(type_, required, key, default)
def table(table_name: str):
def decorator(cls):
# Store table name
cls._table_name = table_name
# Process column definitions
columns: dict[str, Column] = {}
for attr_name, attr_value in cls.__dict__.items():
if isinstance(attr_value, Column):
columns[attr_name] = attr_value
# Store columns metadata
cls._columns = columns
cls._key_columns = [col for col in columns if columns[col].key]
cls._column_csv = ", ".join([col for col in columns])
cls._where_clause = " AND ".join([f"{col} = ?" for col in cls._key_columns])
# Add initialization
original_init = cls.__init__
@wraps(original_init)
def new_init(self, *args, **kwargs):
# Initialize columns with default values
for col_name, col_def in cls._columns.items():
setattr(self, col_name, col_def.default)
# Call original init
original_init(self, *args, **kwargs)
cls.__init__ = new_init
return cls
return decorator
def test():
@table("models")
class Model(BaseEntity):
id: int = column(int, required=True, key=True)
path: str = column(str, required=True)
name: str = column(str, required=True)
description: Optional[str] = column(str)
architecture: Optional[str] = column(str)
type: str = column(str, required=True)
hash: Optional[str] = column(str)
source_url: Optional[str] = column(str)
return Model
@table("test")
class Test(GetEntity, CreateEntity):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
Model = test()

app/database/models.py Normal file

@@ -0,0 +1,76 @@
from sqlalchemy import (
Column,
Integer,
Text,
DateTime,
Table,
ForeignKeyConstraint,
)
from sqlalchemy.orm import relationship, declarative_base
from sqlalchemy.sql import func
Base = declarative_base()
def to_dict(obj):
fields = obj.__table__.columns.keys()
return {
field: (val.to_dict() if hasattr(val, "to_dict") else val)
for field in fields
if (val := getattr(obj, field))
}
ModelTag = Table(
"model_tag",
Base.metadata,
Column(
"model_type",
Text,
primary_key=True,
),
Column(
"model_path",
Text,
primary_key=True,
),
Column("tag_id", Integer, primary_key=True),
ForeignKeyConstraint(
["model_type", "model_path"], ["model.type", "model.path"], ondelete="CASCADE"
),
ForeignKeyConstraint(["tag_id"], ["tag.id"], ondelete="CASCADE"),
)
class Model(Base):
__tablename__ = "model"
type = Column(Text, primary_key=True)
path = Column(Text, primary_key=True)
title = Column(Text)
description = Column(Text)
architecture = Column(Text)
hash = Column(Text)
source_url = Column(Text)
date_added = Column(DateTime, server_default=func.now())
# Relationship with tags
tags = relationship("Tag", secondary=ModelTag, back_populates="models")
def to_dict(self):
dict = to_dict(self)
dict["tags"] = [tag.to_dict() for tag in self.tags]
return dict
class Tag(Base):
__tablename__ = "tag"
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Text, nullable=False, unique=True)
# Relationship with models
models = relationship("Model", secondary=ModelTag, back_populates="tags")
def to_dict(self):
return to_dict(self)
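
A short sketch of how these mappings compose; the tag name and file path are illustrative:

# Sketch (not part of this commit): to_dict() skips NULL/falsy columns, and
# Model.to_dict() additionally inlines the related tags via the model_tag table.
from app.database.db import create_session
from app.database.models import Model, Tag

with create_session() as session:
    tag = session.query(Tag).filter_by(name="style").first()
    model = Model(type="loras", path="example.safetensors", title="Example")
    if tag:
        model.tags.append(tag)
    session.add(model)
    session.commit()
    print(model.to_dict())  # includes a "tags" list with the attached tag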


@@ -1,32 +0,0 @@
from app.database.db import db
from aiohttp import web
def create_routes(
routes, prefix, entity, get=False, get_by_id=False, post=False, delete=False
):
if get:
@routes.get(f"/{prefix}/{table}")
async def get_table(request):
connection = db.get_connection()
cursor = connection.cursor()
cursor.execute(f"SELECT * FROM {table}")
rows = cursor.fetchall()
return web.json_response(rows)
if get_by_id:
@routes.get(f"/{prefix}/{table}/{id}")
async def get_table_by_id(request):
connection = db.get_connection()
cursor = connection.cursor()
cursor.execute(f"SELECT * FROM {table} WHERE id = {id}")
row = cursor.fetchone()
return web.json_response(row)
if post:
@routes.post(f"/{prefix}/{table}")
async def post_table(request):
data = await request.json()
connection = db.get_connection()
cursor = connection.cursor()
cursor.execute(f"INSERT INTO {table} ({data}) VALUES ({data})")
return web.json_response({"status": "success"})

app/database/updater.py

@@ -1,79 +0,0 @@
import logging
import os
import sqlite3
from app.database.versions.v1 import v1
class DatabaseUpdater:
def __init__(self, connection, database_path):
self.connection = connection
self.database_path = database_path
self.current_version = self.get_db_version()
self.version_updates = {
1: v1,
}
self.max_version = max(self.version_updates.keys())
self.update_required = self.current_version < self.max_version
logging.info(f"Database version: {self.current_version}")
def get_db_version(self):
return self.connection.execute("PRAGMA user_version").fetchone()[0]
def backup(self):
bkp_path = self.database_path + ".bkp"
if os.path.exists(bkp_path):
# TODO: auto-rollback failed upgrades
raise Exception(
f"Database backup already exists, this indicates that a previous upgrade failed. Please restore this backup before continuing. Backup location: {bkp_path}"
)
bkp = sqlite3.connect(bkp_path)
self.connection.backup(bkp)
bkp.close()
logging.info("Database backup taken pre-upgrade.")
return bkp_path
def update(self):
if not self.update_required:
return None
bkp_version = self.current_version
bkp_path = None
if self.current_version > 0:
bkp_path = self.backup()
logging.info(f"Updating database: {self.current_version} -> {self.max_version}")
dirname = os.path.dirname(__file__)
cursor = self.connection.cursor()
for version in range(self.current_version + 1, self.max_version + 1):
filename = os.path.join(dirname, f"versions/v{version}.sql")
if not os.path.exists(filename):
raise Exception(
f"Database update script for version {version} not found"
)
try:
with open(filename, "r") as file:
sql = file.read()
cursor.executescript(sql)
except Exception as e:
raise Exception(
f"Failed to execute update script for version {version}: {e}"
)
method = self.version_updates[version]
if method is not None:
method(cursor)
cursor.execute("PRAGMA user_version = %d" % self.max_version)
self.connection.commit()
cursor.close()
self.current_version = self.get_db_version()
if bkp_path:
# Keep a copy of the backup in case something goes wrong and we need to rollback
os.rename(bkp_path, self.database_path + f".v{bkp_version}.bkp")
logging.info(f"Upgrade to successful.")
return (bkp_version, self.current_version)

app/database/versions/v1.py

@@ -1,17 +0,0 @@
from folder_paths import folder_names_and_paths, get_filename_list, get_full_path
def v1(cursor):
print("Updating to v1")
for folder_name in folder_names_and_paths.keys():
if folder_name == "custom_nodes":
continue
files = get_filename_list(folder_name)
for file in files:
file_path = get_full_path(folder_name, file)
file_without_extension = file.rsplit(".", maxsplit=1)[0]
cursor.execute(
"INSERT INTO models (path, name, type) VALUES (?, ?, ?)",
(file_path, file_without_extension, folder_name),
)

app/database/versions/v1.sql

@@ -1,41 +0,0 @@
CREATE TABLE IF NOT EXISTS
models (
id INTEGER PRIMARY KEY AUTOINCREMENT,
path TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
architecture TEXT,
type TEXT NOT NULL,
hash TEXT,
source_url TEXT
);
CREATE TABLE IF NOT EXISTS
tags (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS
model_tags (
model_id INTEGER NOT NULL,
tag_id INTEGER NOT NULL,
PRIMARY KEY (model_id, tag_id),
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE,
FOREIGN KEY (tag_id) REFERENCES tags (id) ON DELETE CASCADE
);
INSERT INTO
tags (name)
VALUES
('character'),
('style'),
('concept'),
('clothing'),
('poses'),
('background'),
('vehicle'),
('buildings'),
('objects'),
('animal'),
('action');

app/model_hasher.py

@@ -1,63 +0,0 @@
import hashlib
import logging
import threading
import time
from comfy.cli_args import args
class ModelHasher:
def __init__(self):
self._thread = None
self._lock = threading.Lock()
self._model_entity = None
def start(self, model_entity):
if args.disable_model_hashing:
return
self._model_entity = model_entity
if self._thread is None:
# Lock to prevent multiple threads from starting
with self._lock:
if self._thread is None:
self._thread = threading.Thread(target=self._hash_models)
self._thread.daemon = True
self._thread.start()
def _get_models(self):
models = self._model_entity.get("WHERE hash IS NULL")
return models
def _hash_model(self, model_path):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(model_path, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
hash = h.hexdigest()
return hash
def _hash_models(self):
while True:
models = self._get_models()
if len(models) == 0:
break
for model in models:
time.sleep(0)
now = time.time()
logging.info(f"Hashing model {model['path']}")
hash = self._hash_model(model["path"])
logging.info(
f"Hashed model {model['path']} in {time.time() - now} seconds"
)
self._model_entity.update((model["id"],), {"hash": hash})
self._thread = None
model_hasher = ModelHasher()

app/model_manager.py

@@ -1,19 +1,30 @@
from __future__ import annotations
import os
import base64
import json
import time
import logging
from app.database.db import create_session
import folder_paths
import glob
import comfy.utils
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
from folder_paths import map_legacy, filter_files_extensions, get_full_path
from app.database.models import Tag, Model
from app.model_processor import get_model_previews, model_processor
from utils.web import dumps
from sqlalchemy.orm import joinedload
import sqlalchemy.exc
def bad_request(message: str):
return web.json_response({"error": message}, status=400)
def missing_field(field: str):
return bad_request(f"{field} is required")
def not_found(message: str):
return web.json_response({"error": message + " not found"}, status=404)
class ModelFileManager:
def __init__(self) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
@@ -62,7 +73,7 @@ class ModelFileManager:
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
previews = self.get_model_previews(full_filename)
previews = get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
@@ -76,6 +87,183 @@ class ModelFileManager:
except:
return web.Response(status=404)
@routes.get("/v2/models")
async def get_models(request):
with create_session() as session:
model_path = request.query.get("path", None)
model_type = request.query.get("type", None)
query = session.query(Model).options(joinedload(Model.tags))
if model_path:
query = query.filter(Model.path == model_path)
if model_type:
query = query.filter(Model.type == model_type)
models = query.all()
if model_path and model_type:
if len(models) == 0:
return not_found("Model")
return web.json_response(models[0].to_dict(), dumps=dumps)
return web.json_response([model.to_dict() for model in models], dumps=dumps)
@routes.post("/v2/models")
async def add_model(request):
with create_session() as session:
data = await request.json()
model_type = data.get("type", None)
model_path = data.get("path", None)
if not model_type:
return missing_field("type")
if not model_path:
return missing_field("path")
tags = data.pop("tags", [])
fields = Model.metadata.tables["model"].columns.keys()
# Validate keys are valid model fields
for key in data.keys():
if key not in fields:
return bad_request(f"Invalid field: {key}")
# Validate file exists
if not get_full_path(model_type, model_path):
return not_found(f"File '{model_type}/{model_path}'")
model = Model()
for field in fields:
if field in data:
setattr(model, field, data[field])
model.tags = session.query(Tag).filter(Tag.id.in_(tags)).all()
for tag in tags:
if tag not in [t.id for t in model.tags]:
return not_found(f"Tag '{tag}'")
try:
session.add(model)
session.commit()
except sqlalchemy.exc.IntegrityError as e:
session.rollback()
return bad_request(e.orig.args[0])
model_processor.run()
return web.json_response(model.to_dict(), dumps=dumps)
@routes.delete("/v2/models")
async def delete_model(request):
with create_session() as session:
model_path = request.query.get("path", None)
model_type = request.query.get("type", None)
if not model_path:
return missing_field("path")
if not model_type:
return missing_field("type")
full_path = get_full_path(model_type, model_path)
if full_path:
return bad_request("Model file exists, please delete the file before deleting the model record.")
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
if not model:
return not_found("Model")
session.delete(model)
session.commit()
return web.Response(status=204)
@routes.get("/v2/tags")
async def get_tags(request):
with create_session() as session:
tags = session.query(Tag).all()
return web.json_response(
[{"id": tag.id, "name": tag.name} for tag in tags]
)
@routes.post("/v2/tags")
async def create_tag(request):
with create_session() as session:
data = await request.json()
name = data.get("name", None)
if not name:
return missing_field("name")
tag = Tag(name=name)
session.add(tag)
session.commit()
return web.json_response({"id": tag.id, "name": tag.name})
@routes.delete("/v2/tags")
async def delete_tag(request):
with create_session() as session:
tag_id = request.query.get("id", None)
if not tag_id:
return missing_field("id")
tag = session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
return not_found("Tag")
session.delete(tag)
session.commit()
return web.Response(status=204)
@routes.post("/v2/models/tags")
async def add_model_tag(request):
with create_session() as session:
data = await request.json()
tag_id = data.get("tag", None)
model_path = data.get("path", None)
model_type = data.get("type", None)
if tag_id is None:
return missing_field("tag")
if model_path is None:
return missing_field("path")
if model_type is None:
return missing_field("type")
try:
tag_id = int(tag_id)
except ValueError:
return bad_request("Invalid tag id")
tag = session.query(Tag).filter(Tag.id == tag_id).first()
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
if not model:
return not_found("Model")
model.tags.append(tag)
session.commit()
return web.json_response(model.to_dict(), dumps=dumps)
@routes.delete("/v2/models/tags")
async def delete_model_tag(request):
with create_session() as session:
tag_id = request.query.get("tag", None)
model_path = request.query.get("path", None)
model_type = request.query.get("type", None)
if tag_id is None:
return missing_field("tag")
if model_path is None:
return missing_field("path")
if model_type is None:
return missing_field("type")
try:
tag_id = int(tag_id)
except ValueError:
return bad_request("Invalid tag id")
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
if not model:
return not_found("Model")
model.tags = [tag for tag in model.tags if tag.id != tag_id]
session.commit()
return web.Response(status=204)
@routes.get("/v2/models/missing")
async def get_missing_models(request):
return web.json_response(model_processor.missing_models)
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
@@ -146,39 +334,5 @@ class ModelFileManager:
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):
self.clear_cache()
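
For a sense of the resulting HTTP surface, a hedged client sketch against the new /v2 endpoints, assuming a local server on ComfyUI's default port and the /api prefix that PromptServer adds to every route:

# Sketch (not part of this commit): exercising /v2/models and /v2/tags.
import json
from urllib import request

BASE = "http://127.0.0.1:8188/api"

def post(path, payload):
    req = request.Request(
        BASE + path,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with request.urlopen(req) as resp:
        return json.loads(resp.read()) if resp.status != 204 else None

# List all model records (filter with ?type=...&path=... for a single one)
with request.urlopen(f"{BASE}/v2/models") as resp:
    models = json.loads(resp.read())

# Create a tag, then attach it to a model record
tag = post("/v2/tags", {"name": "style"})
post("/v2/models/tags", {"tag": tag["id"], "type": "loras", "path": "example.safetensors"})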

app/model_processor.py Normal file

@@ -0,0 +1,263 @@
import base64
from datetime import datetime
import glob
import hashlib
from io import BytesIO
import json
import logging
import os
import threading
import time
import comfy.utils
from app.database.models import Model
from app.database.db import create_session
from comfy.cli_args import args
from folder_paths import (
filter_files_content_types,
get_full_path,
folder_names_and_paths,
get_filename_list,
)
from PIL import Image
from urllib import request
def get_model_previews(
filepath: str, check_metadata: bool = True
) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if not check_metadata:
return result
safetensors_file = next(
filter(lambda x: x.endswith(".safetensors"), match_files), None
)
safetensors_metadata = {}
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(
safetensors_filepath, max_size=8 * 1024 * 1024
)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get(
"ssmd_cover_images", None
)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
class ModelProcessor:
def __init__(self):
self._thread = None
self._lock = threading.Lock()
self._run = False
self.missing_models = []
def run(self):
if args.disable_model_processing:
return
if self._thread is None:
# Lock to prevent multiple threads from starting
with self._lock:
self._run = True
if self._thread is None:
self._thread = threading.Thread(target=self._process_models)
self._thread.daemon = True
self._thread.start()
def populate_models(self, session):
# Ensure database state matches filesystem
existing_models = session.query(Model).all()
for folder_name in folder_names_and_paths.keys():
if folder_name == "custom_nodes" or folder_name == "configs":
continue
seen = set()
files = get_filename_list(folder_name)
for file in files:
if file in seen:
logging.warning(f"Skipping duplicate named model: {file}")
continue
seen.add(file)
existing_model = None
for model in existing_models:
if model.path == file and model.type == folder_name:
existing_model = model
break
if existing_model:
# Model already exists in db, remove from list and skip
existing_models.remove(existing_model)
continue
file_path = get_full_path(folder_name, file)
model = Model(
path=file,
type=folder_name,
date_added=datetime.fromtimestamp(os.path.getctime(file_path)),
)
session.add(model)
for model in existing_models:
if not get_full_path(model.type, model.path):
logging.warning(f"Model {model.path} not found")
self.missing_models.append({"type": model.type, "path": model.path})
session.commit()
def _get_models(self, session):
models = session.query(Model).filter(Model.hash == None).all()
return models
def _process_file(self, model_path):
is_safetensors = model_path.endswith(".safetensors")
metadata = {}
h = hashlib.sha256()
with open(model_path, "rb", buffering=0) as f:
if is_safetensors:
# Read header length (8 bytes)
header_size_bytes = f.read(8)
header_len = int.from_bytes(header_size_bytes, "little")
h.update(header_size_bytes)
# Read header
header_bytes = f.read(header_len)
h.update(header_bytes)
try:
metadata = json.loads(header_bytes)
except json.JSONDecodeError:
pass
# Read rest of file
b = bytearray(128 * 1024)
mv = memoryview(b)
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest(), metadata
def _populate_info(self, model, metadata):
model.title = metadata.get("modelspec.title", None)
model.description = metadata.get("modelspec.description", None)
model.architecture = metadata.get("modelspec.architecture", None)
def _extract_image(self, model_path, metadata):
# check if image already exists
if len(get_model_previews(model_path, check_metadata=False)) > 0:
return
image_path = os.path.splitext(model_path)[0] + ".webp"
if os.path.exists(image_path):
return
cover_images = metadata.get("ssmd_cover_images", None)
image = None
if cover_images:
try:
cover_images = json.loads(cover_images)
if len(cover_images) > 0:
image_data = cover_images[0]
image = Image.open(BytesIO(base64.b64decode(image_data)))
except Exception as e:
logging.warning(
f"Error extracting cover image for model {model_path}: {e}"
)
if not image:
thumbnail = metadata.get("modelspec.thumbnail", None)
if thumbnail:
try:
response = request.urlopen(thumbnail)
image = Image.open(response)
except Exception as e:
logging.warning(
f"Error extracting thumbnail for model {model_path}: {e}"
)
if image:
image.thumbnail((512, 512))
image.save(image_path)
image.close()
def _process_models(self):
with create_session() as session:
checked = set()
self.populate_models(session)
while self._run:
self._run = False
models = self._get_models(session)
if len(models) == 0:
break
for model in models:
# prevent looping on the same model if it crashes
if model.path in checked:
continue
checked.add(model.path)
try:
time.sleep(0)
now = time.time()
model_path = get_full_path(model.type, model.path)
if not model_path:
logging.warning(f"Model {model.path} not found")
self.missing_models.append(model.path)
continue
logging.debug(f"Processing model {model_path}")
hash, header = self._process_file(model_path)
logging.debug(
f"Processed model {model_path} in {time.time() - now} seconds"
)
model.hash = hash
if header:
metadata = header.get("__metadata__", None)
if metadata:
self._populate_info(model, metadata)
self._extract_image(model_path, metadata)
session.commit()
except Exception as e:
logging.error(f"Error processing model {model.path}: {e}")
with self._lock:
self._thread = None
model_processor = ModelProcessor()
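
To make the hashing scheme in `_process_file` concrete: for a .safetensors file the first 8 bytes are a little-endian length of the JSON header, so the header can be parsed for metadata while every byte still passes through the digest. A standalone sketch (hypothetical helper name):

# Sketch (not part of this commit): _process_file's logic for .safetensors.
import hashlib
import json

def hash_and_read_header(path):
    h = hashlib.sha256()
    metadata = {}
    with open(path, "rb", buffering=0) as f:
        header_size_bytes = f.read(8)          # little-endian header length
        h.update(header_size_bytes)
        header_bytes = f.read(int.from_bytes(header_size_bytes, "little"))
        h.update(header_bytes)                 # header is hashed too
        try:
            metadata = json.loads(header_bytes)
        except json.JSONDecodeError:
            pass
        mv = memoryview(bytearray(128 * 1024))
        while n := f.readinto(mv):             # stream the tensor payload
            h.update(mv[:n])
    return h.hexdigest(), metadata.get("__metadata__", {})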

comfy/cli_args.py

@@ -143,13 +143,9 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
parser.add_argument("--memory-database", default=False, action="store_true", help="Use an in-memory database instead of a file-based one.")
parser.add_argument("--disable-model-hashing", action="store_true", help="Disable model hashing.")
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
parser.add_argument(
"--front-end-version",
type=str,
@@ -182,6 +178,12 @@ parser.add_argument(
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
if comfy.options.args_parsing:
args = parser.parse_args()
else:

main.py

@@ -138,6 +138,8 @@ import server
from server import BinaryEventTypes
import nodes
import comfy.model_management
from app.database.db import can_create_session, init_db
from app.model_processor import model_processor
def cuda_malloc_warning():
device = comfy.model_management.get_torch_device()
@@ -262,6 +264,11 @@ def start_comfyui(asyncio_loop=None):
cuda_malloc_warning()
try:
init_db()
except Exception as e:
logging.error(f"Failed to initialize database. Please report this error as in future the database will be required: {e}")
prompt_server.add_routes()
hijack_progress(prompt_server)
@@ -269,6 +276,10 @@ def start_comfyui(asyncio_loop=None):
if args.quick_test_for_ci:
exit(0)
# Scan for changed model files and update db
if can_create_session():
model_processor.run()
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
call_on_start = None

requirements.txt

@@ -13,6 +13,8 @@ Pillow
scipy
tqdm
psutil
alembic
SQLAlchemy
#non essential dependencies:
kornia>=0.7.1

server.py

@@ -34,9 +34,6 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes
from app.database.entities import get_entity, init_entities
from app.database.db import db
from app.model_hasher import model_hasher
class BinaryEventTypes:
PREVIEW_IMAGE = 1
@@ -685,25 +682,11 @@ class PromptServer():
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)
def init_db(self, routes):
init_entities(routes)
models = get_entity("models")
if db.exists:
model_hasher.start(models)
else:
def on_db_update(_, __):
model_hasher.start(models)
db.register_after_update_callback(on_db_update)
def add_routes(self):
self.user_manager.add_routes(self.routes)
self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.app.add_subapp('/internal', self.internal_routes.get_app())
self.init_db(self.routes)
# Prefix every route with /api for easier matching for delegation.
# This is very useful for frontend dev server, which need to forward


@@ -1,513 +0,0 @@
from comfy.cli_args import args
args.memory_database = True # force in-memory database for testing
from typing import Callable, Optional
import pytest
import pytest_asyncio
from unittest.mock import patch
from aiohttp import web
from app.database.entities import (
DeleteEntity,
column,
table,
Column,
GetEntity,
GetEntityById,
CreateEntity,
UpsertEntity,
UpdateEntity,
)
from app.database.db import db
pytestmark = pytest.mark.asyncio
def create_table(entity):
# reset db
db.close()
cols: list[Column] = entity._columns
# Create tables as temporary so when we close the db, the tables are dropped for next test
sql = f"CREATE TEMPORARY TABLE {entity._table_name} ( "
for col_name, col in cols.items():
type = None
if col.type == int:
type = "INTEGER"
elif col.type == str:
type = "TEXT"
sql += f"{col_name} {type}"
if col.required:
sql += " NOT NULL"
sql += ", "
sql += f"PRIMARY KEY ({', '.join(entity._key_columns)})"
sql += ")"
db.execute(sql)
async def wrap_db(method: Callable, expected_sql: str, expected_args: list):
with patch.object(db, "execute", wraps=db.execute) as mock:
response = await method()
assert mock.call_count == 1
assert mock.call_args[0][0] == expected_sql
assert mock.call_args[0][1:] == expected_args
return response
@pytest.fixture
def getable_entity():
@table("getable_entity")
class GetableEntity(GetEntity):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
nullable: Optional[str] = column(str)
return GetableEntity
@pytest.fixture
def getable_by_id_entity():
@table("getable_by_id_entity")
class GetableByIdEntity(GetEntityById):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
return GetableByIdEntity
@pytest.fixture
def getable_by_id_composite_entity():
@table("getable_by_id_composite_entity")
class GetableByIdCompositeEntity(GetEntityById):
id1: str = column(str, required=True, key=True)
id2: int = column(int, required=True, key=True)
test: str = column(str, required=True)
return GetableByIdCompositeEntity
@pytest.fixture
def creatable_entity():
@table("creatable_entity")
class CreatableEntity(CreateEntity):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
reqd: str = column(str, required=True)
nullable: Optional[str] = column(str)
return CreatableEntity
@pytest.fixture
def upsertable_entity():
@table("upsertable_entity")
class UpsertableEntity(UpsertEntity):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
reqd: str = column(str, required=True)
nullable: Optional[str] = column(str)
return UpsertableEntity
@pytest.fixture
def updateable_entity():
@table("updateable_entity")
class UpdateableEntity(UpdateEntity):
id: int = column(int, required=True, key=True)
reqd: str = column(str, required=True)
return UpdateableEntity
@pytest.fixture
def deletable_entity():
@table("deletable_entity")
class DeletableEntity(DeleteEntity):
id: int = column(int, required=True, key=True)
return DeletableEntity
@pytest.fixture
def deletable_composite_entity():
@table("deletable_composite_entity")
class DeletableCompositeEntity(DeleteEntity):
id1: str = column(str, required=True, key=True)
id2: int = column(int, required=True, key=True)
return DeletableCompositeEntity
@pytest.fixture()
def entity(request):
value = request.getfixturevalue(request.param)
create_table(value)
return value
@pytest_asyncio.fixture
async def client(aiohttp_client, app):
return await aiohttp_client(app)
@pytest.fixture
def app(entity):
app = web.Application()
routes = web.RouteTableDef()
entity.register_route(routes)
app.add_routes(routes)
return app
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_empty_response(client):
expected_sql = "SELECT * FROM getable_entity"
expected_args = ()
response = await wrap_db(
lambda: client.get("/db/getable_entity"), expected_sql, expected_args
)
assert response.status == 200
assert await response.json() == []
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_data(client):
# seed db
db.execute(
"INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2')"
)
expected_sql = "SELECT * FROM getable_entity"
expected_args = ()
response = await wrap_db(
lambda: client.get("/db/getable_entity"), expected_sql, expected_args
)
assert response.status == 200
assert await response.json() == [
{"id": 1, "test": "test1", "nullable": None},
{"id": 2, "test": "test2", "nullable": "test2"},
]
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_top_parameter(client):
# seed with 3 rows
db.execute(
"INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2'), (3, 'test3', 'test3')"
)
expected_sql = "SELECT * FROM getable_entity LIMIT 2"
expected_args = ()
response = await wrap_db(
lambda: client.get("/db/getable_entity?top=2"),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == [
{"id": 1, "test": "test1", "nullable": None},
{"id": 2, "test": "test2", "nullable": "test2"},
]
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_invalid_top_parameter(client):
response = await client.get("/db/getable_entity?top=hello")
assert response.status == 400
assert await response.json() == {
"message": "Invalid top parameter",
"field": "top",
"value": "hello",
}
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
async def test_get_model_by_id_empty_response(client):
# seed db
db.execute("INSERT INTO getable_by_id_entity (id, test) VALUES (1, 'test1')")
expected_sql = "SELECT * FROM getable_by_id_entity WHERE id = ?"
expected_args = (1,)
response = await wrap_db(
lambda: client.get("/db/getable_by_id_entity/1"),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == [
{"id": 1, "test": "test1"},
]
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
async def test_get_model_by_id_with_invalid_id(client):
response = await client.get("/db/getable_by_id_entity/hello")
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "id",
"value": "hello",
}
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
async def test_get_model_by_id_composite(client):
# seed db
db.execute(
"INSERT INTO getable_by_id_composite_entity (id1, id2, test) VALUES ('one', 2, 'test')"
)
expected_sql = (
"SELECT * FROM getable_by_id_composite_entity WHERE id1 = ? AND id2 = ?"
)
expected_args = ("one", 2)
response = await wrap_db(
lambda: client.get("/db/getable_by_id_composite_entity/one/2"),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == [
{"id1": "one", "id2": 2, "test": "test"},
]
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
async def test_get_model_by_id_composite_with_invalid_id(client):
response = await client.get("/db/getable_by_id_composite_entity/hello/hello")
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "id2",
"value": "hello",
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model(client):
expected_sql = (
"INSERT INTO creatable_entity (id, test, reqd) VALUES (?, ?, ?) RETURNING *"
)
expected_args = (1, "test1", "reqd1")
response = await wrap_db(
lambda: client.post(
"/db/creatable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == {
"id": 1,
"test": "test1",
"reqd": "reqd1",
"nullable": None,
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_missing_required_field(client):
response = await client.post(
"/db/creatable_entity", json={"id": 1, "test": "test1"}
)
assert response.status == 400
assert await response.json() == {
"message": "Missing field",
"field": "reqd",
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_missing_key_field(client):
response = await client.post(
"/db/creatable_entity",
json={"test": "test1", "reqd": "reqd1"}, # Missing 'id' which is a key
)
assert response.status == 400
assert await response.json() == {
"message": "Missing field",
"field": "id",
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_key_data(client):
response = await client.post(
"/db/creatable_entity",
json={
"id": "not_an_integer",
"test": "test1",
"reqd": "reqd1",
}, # id should be int
)
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "id",
"value": "not_an_integer",
}

@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_data(client):
response = await client.post(
"/db/creatable_entity",
json={"id": "aaa", "test": "123", "reqd": "reqd1"}, # id should be int
)
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "id",
"value": "aaa",
}

@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_type(client):
response = await client.post(
"/db/creatable_entity",
json={
"id": 1,
"test": ["invalid_array"],
"reqd": "reqd1",
}, # test should be string
)
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "test",
"value": ["invalid_array"],
}

@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_name(client):
response = await client.post(
"/db/creatable_entity",
json={"id": 1, "test": "test1", "reqd": "reqd1", "nonexistent_field": "value"},
)
assert response.status == 400
assert await response.json() == {
"message": "Unknown field",
"field": "nonexistent_field",
}
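
# Upsert (PUT) relies on INSERT ... ON CONFLICT (key) DO UPDATE, again
# returning the stored row via RETURNING *.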
@pytest.mark.parametrize("entity", ["upsertable_entity"], indirect=True)
async def test_upsert_model(client):
expected_sql = (
"INSERT INTO upsertable_entity (id, test, reqd) VALUES (?, ?, ?) "
"ON CONFLICT (id) DO UPDATE SET test = excluded.test, reqd = excluded.reqd "
"RETURNING *"
)
expected_args = (1, "test1", "reqd1")
response = await wrap_db(
lambda: client.put(
"/db/upsertable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == {
"id": 1,
"test": "test1",
"reqd": "reqd1",
"nullable": None,
}
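
# Update (PATCH) sets only the columns present in the request body; updating
# a missing record surfaces as a 404.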
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model(client):
# seed db
db.execute("INSERT INTO updateable_entity (id, reqd) VALUES (1, 'test1')")
expected_sql = "UPDATE updateable_entity SET reqd = ? WHERE id = ? RETURNING *"
expected_args = ("updated_test", 1)
response = await wrap_db(
lambda: client.patch("/db/updateable_entity/1", json={"reqd": "updated_test"}),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == {
"id": 1,
"reqd": "updated_test",
}

@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model_reject_null_required_field(client):
response = await client.patch("/db/updateable_entity/1", json={"reqd": None})
assert response.status == 400
assert await response.json() == {
"message": "Required field",
"field": "reqd",
}

@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model_reject_invalid_field(client):
response = await client.patch("/db/updateable_entity/1", json={"hello": "world"})
assert response.status == 400
assert await response.json() == {
"message": "Unknown field",
"field": "hello",
}

@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model_reject_missing_record(client):
response = await client.patch(
"/db/updateable_entity/1", json={"reqd": "updated_test"}
)
assert response.status == 404
assert await response.json() == {
"message": "Failed to update entity",
}
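
# Delete builds its WHERE clause from the key columns (composite keys
# included) and answers 204 with no body.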
@pytest.mark.parametrize("entity", ["deletable_entity"], indirect=True)
async def test_delete_model(client):
expected_sql = "DELETE FROM deletable_entity WHERE id = ?"
expected_args = (1,)
response = await wrap_db(
lambda: client.delete("/db/deletable_entity/1"),
expected_sql,
expected_args,
)
assert response.status == 204

@pytest.mark.parametrize("entity", ["deletable_composite_entity"], indirect=True)
async def test_delete_model_composite_key(client):
expected_sql = "DELETE FROM deletable_composite_entity WHERE id1 = ? AND id2 = ?"
expected_args = ("one", 2)
response = await wrap_db(
lambda: client.delete("/db/deletable_composite_entity/one/2"),
expected_sql,
expected_args,
)
assert response.status == 204


@@ -7,11 +7,33 @@ from PIL import Image
from aiohttp import web
from unittest.mock import patch
from app.model_manager import ModelFileManager
from app.database.models import Base, Model, Tag
from comfy.cli_args import args
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module

@pytest.fixture
def session():
# Configure in-memory database
args.database_url = "sqlite:///:memory:"
# Create engine and session factory
engine = create_engine(args.database_url)
Session = sessionmaker(bind=engine)
# Create all tables
Base.metadata.create_all(engine)
# Patch Session factory
with patch('app.database.db.Session', Session):
yield Session()
Base.metadata.drop_all(engine)

@pytest.fixture
def model_manager():
return ModelFileManager()

@@ -60,3 +82,287 @@ async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
# Clean up
img.close()
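
# The /v2 endpoints below run against the in-memory engine wired up by the
# session fixture above.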
async def test_get_models(aiohttp_client, app, session):
tag = Tag(name='test_tag')
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
model.tags.append(tag)
session.add(tag)
session.add(model)
session.commit()
client = await aiohttp_client(app)
resp = await client.get('/v2/models')
assert resp.status == 200
data = await resp.json()
assert len(data) == 1
assert data[0]['path'] == 'model1.safetensors'
assert len(data[0]['tags']) == 1
assert data[0]['tags'][0]['name'] == 'test_tag'

async def test_add_model(aiohttp_client, app, session):
tag = Tag(name='test_tag')
session.add(tag)
session.commit()
tag_id = tag.id
with patch('app.model_manager.model_processor') as mock_processor:
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'title': 'Test Model',
'tags': [tag_id]
})
assert resp.status == 200
data = await resp.json()
assert data['path'] == 'model1.safetensors'
assert len(data['tags']) == 1
assert data['tags'][0]['name'] == 'test_tag'
# Ensure that models are re-processed after adding
mock_processor.run.assert_called_once()
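
# Deletion is refused while the file still exists on disk: get_full_path
# returning None simulates a removed file, a real path yields a 400.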
async def test_delete_model(aiohttp_client, app, session):
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
with patch('app.model_manager.get_full_path', return_value=None):
client = await aiohttp_client(app)
resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
assert resp.status == 204
# Verify model was deleted
model = session.query(Model).first()
assert model is None

async def test_delete_model_file_exists(aiohttp_client, app, session):
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
assert resp.status == 400
data = await resp.json()
assert "file exists" in data["error"].lower()
# Verify model was not deleted
model = session.query(Model).first()
assert model is not None
assert model.path == 'model1.safetensors'

async def test_get_tags(aiohttp_client, app, session):
tags = [Tag(name='tag1'), Tag(name='tag2')]
for tag in tags:
session.add(tag)
session.commit()
client = await aiohttp_client(app)
resp = await client.get('/v2/tags')
assert resp.status == 200
data = await resp.json()
assert len(data) == 2
assert {t['name'] for t in data} == {'tag1', 'tag2'}

async def test_create_tag(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/tags', json={'name': 'new_tag'})
assert resp.status == 200
data = await resp.json()
assert data['name'] == 'new_tag'
# Verify tag was created
tag = session.query(Tag).first()
assert tag.name == 'new_tag'

async def test_delete_tag(aiohttp_client, app, session):
tag = Tag(name='test_tag')
session.add(tag)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.delete(f'/v2/tags?id={tag_id}')
assert resp.status == 204
# Verify tag was deleted
tag = session.query(Tag).first()
assert tag is None

async def test_add_model_tag(aiohttp_client, app, session):
tag = Tag(name='test_tag')
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(tag)
session.add(model)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.post('/v2/models/tags', json={
'tag': tag_id,
'type': 'checkpoints',
'path': 'model1.safetensors'
})
assert resp.status == 200
data = await resp.json()
assert len(data['tags']) == 1
assert data['tags'][0]['name'] == 'test_tag'

async def test_delete_model_tag(aiohttp_client, app, session):
tag = Tag(name='test_tag')
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
model.tags.append(tag)
session.add(tag)
session.add(model)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.delete(f'/v2/models/tags?tag={tag_id}&type=checkpoints&path=model1.safetensors')
assert resp.status == 204
# Verify tag was removed
model = session.query(Model).first()
assert len(model.tags) == 0
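
# Error paths: duplicates, missing or unknown fields, and references to
# nonexistent tags/models all short-circuit with a 4xx JSON "error" payload.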
async def test_add_model_duplicate(aiohttp_client, app, session):
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'title': 'Duplicate Model'
})
assert resp.status == 400

async def test_add_model_missing_fields(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={})
assert resp.status == 400

async def test_add_tag_missing_name(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/tags', json={})
assert resp.status == 400

async def test_delete_model_not_found(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.delete('/v2/models?type=checkpoints&path=nonexistent.safetensors')
assert resp.status == 404

async def test_delete_tag_not_found(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.delete('/v2/tags?id=999')
assert resp.status == 404

async def test_add_model_missing_path(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'title': 'Test Model'
})
assert resp.status == 400
data = await resp.json()
assert "path" in data["error"].lower()

async def test_add_model_invalid_field(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'invalid_field': 'some value'
})
assert resp.status == 400
data = await resp.json()
assert "invalid field" in data["error"].lower()

async def test_add_model_nonexistent_file(aiohttp_client, app, session):
with patch('app.model_manager.get_full_path', return_value=None):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'nonexistent.safetensors'
})
assert resp.status == 404
data = await resp.json()
assert "file" in data["error"].lower()

async def test_add_model_invalid_tag(aiohttp_client, app, session):
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'tags': [999] # Non-existent tag ID
})
assert resp.status == 404
data = await resp.json()
assert "tag" in data["error"].lower()

async def test_add_tag_to_nonexistent_model(aiohttp_client, app, session):
# Create a tag but no model
tag = Tag(name='test_tag')
session.add(tag)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.post('/v2/models/tags', json={
'tag': tag_id,
'type': 'checkpoints',
'path': 'nonexistent.safetensors'
})
assert resp.status == 404
data = await resp.json()
assert "model" in data["error"].lower()

async def test_delete_model_tag_invalid_tag_id(aiohttp_client, app, session):
# Create a model first
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
client = await aiohttp_client(app)
resp = await client.delete('/v2/models/tags?tag=not_a_number&type=checkpoints&path=model1.safetensors')
assert resp.status == 400
data = await resp.json()
assert "invalid tag id" in data["error"].lower()

12
utils/web.py Normal file

@@ -0,0 +1,12 @@
import json
from datetime import datetime
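
# JSONEncoder that falls back to ISO 8601 strings for datetime values.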
class DateTimeEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
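
# Ready-made serializer: dumps(obj) returns a JSON string, so it can also be
# handed to aiohttp's web.json_response as its dumps= argument.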
dumps = DateTimeEncoder().encode
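
A minimal usage sketch (not part of the diff; the `utils.web` import path assumes the repo root is on sys.path):

from datetime import datetime
from utils.web import dumps

row = {"id": 1, "date_added": datetime(2025, 3, 28, 11, 39, 56)}
print(dumps(row))  # {"id": 1, "date_added": "2025-03-28T11:39:56"}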