Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-08-03 23:49:57 +08:00)

Compare commits
13 Commits
model_mana...v0.3.11
Author | SHA1 | Date
---|---|---
| 619b8cde74 | |
| 31831e6ef1 | |
| 88ceb28e20 | |
| 23289a6a5c | |
| 9d8b6c1f46 | |
| 6320d05696 | |
| 25683b5b02 | |
| 4758fb64b9 | |
| 008761166f | |
| bfd5dfd611 | |
| 55ade36d01 | |
| 2e20e399ea | |
| 3baf92d120 | |
.github/workflows/test-build.yml (vendored, 2 lines changed)
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
     steps:
     - uses: actions/checkout@v4
     - name: Set up Python ${{ matrix.python-version }}
@@ -1,126 +0,0 @@
import logging
import os
import sqlite3
from contextlib import contextmanager
from queue import Queue, Empty, Full
import threading
from app.database.updater import DatabaseUpdater
import folder_paths
from comfy.cli_args import args


class Database:
    def __init__(self, database_path=None, pool_size=1):
        if database_path is None:
            self.exists = False
            database_path = "file::memory:?cache=shared"
        else:
            self.exists = os.path.exists(database_path)

        self.database_path = database_path
        self.pool_size = pool_size
        # Store connections in a pool, default to 1 as normal usage is going to be from a single thread at a time
        self.connection_pool: Queue = Queue(maxsize=pool_size)
        self._db_lock = threading.Lock()
        self._initialized = False
        self._closing = False
        self._after_update_callbacks = []

    def _setup(self):
        if self._initialized:
            return

        with self._db_lock:
            if not self._initialized:
                self._make_db()
                self._initialized = True

    def _create_connection(self):
        # TODO: Catch error for sqlite lib missing on linux
        logging.info(f"Creating connection to {self.database_path}")
        conn = sqlite3.connect(
            self.database_path,
            check_same_thread=False,
            uri=self.database_path.startswith("file::"),
        )
        conn.execute("PRAGMA foreign_keys = ON")
        self.exists = True
        logging.info("Connected!")
        return conn

    def _make_db(self):
        with self._get_connection() as con:
            updater = DatabaseUpdater(con, self.database_path)
            result = updater.update()
            if result is not None:
                old_version, new_version = result

                for callback in self._after_update_callbacks:
                    callback(old_version, new_version)

    def _transform(self, row, columns):
        return {col.name: value for value, col in zip(row, columns)}

    @contextmanager
    def _get_connection(self):
        if self._closing:
            raise Exception("Database is shutting down")

        try:
            # Try to get connection from pool
            connection = self.connection_pool.get_nowait()
        except Empty:
            # Create new connection if pool is empty
            connection = self._create_connection()

        try:
            yield connection
        finally:
            try:
                # Try to add to pool if it's empty
                self.connection_pool.put_nowait(connection)
            except Full:
                # Pool is full, close the connection
                connection.close()

    @contextmanager
    def get_connection(self):
        # Setup the database if it's not already initialized
        self._setup()
        with self._get_connection() as connection:
            yield connection

    def execute(self, sql, *args):
        with self.get_connection() as connection:
            cursor = connection.execute(sql, args)
            results = cursor.fetchall()
            return results

    def register_after_update_callback(self, callback):
        self._after_update_callbacks.append(callback)

    def close(self):
        if self._closing:
            return
        # Drain and close all connections in the pool
        self._closing = True
        while True:
            try:
                conn = self.connection_pool.get_nowait()
                conn.close()
            except Empty:
                break
        self._closing = False

    def __del__(self):
        try:
            self.close()
        except:
            pass


# Create a global instance
db_path = None
if not args.memory_database:
    db_path = folder_paths.get_user_directory() + "/comfyui.db"
db = Database(db_path)
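For context on the pattern the removed Database class uses, here is a minimal, self-contained sketch of a Queue-backed SQLite connection pool handed out through a contextmanager. TinyPool and all names below are illustrative, not part of the repo:

import sqlite3
from contextlib import contextmanager
from queue import Queue, Empty, Full

class TinyPool:
    # Queue-backed pool, same shape as the removed Database class.
    def __init__(self, path=":memory:", size=1):
        self.path = path
        self.pool = Queue(maxsize=size)

    @contextmanager
    def connection(self):
        try:
            conn = self.pool.get_nowait()   # reuse an idle pooled connection
        except Empty:
            conn = sqlite3.connect(self.path, check_same_thread=False)
        try:
            yield conn
        finally:
            try:
                self.pool.put_nowait(conn)  # hand it back to the pool...
            except Full:
                conn.close()                # ...or close the overflow

pool = TinyPool()
with pool.connection() as con:
    print(con.execute("SELECT 1").fetchone())  # (1,)

Note the removed class passes a "file::memory:?cache=shared" URI (with uri=True) so that pooled in-memory connections all see the same data; a plain ":memory:" path like the sketch above would give each new connection its own empty database.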
@@ -1,343 +0,0 @@
from typing import Optional, Any, Callable
from dataclasses import dataclass
from functools import wraps
from aiohttp import web
from app.database.db import db

primitives = (bool, str, int, float, type(None))


def is_primitive(obj):
    return isinstance(obj, primitives)


class EntityError(Exception):
    def __init__(
        self, message: str, field: str = None, value: Any = None, status_code: int = 400
    ):
        self.message = message
        self.field = field
        self.value = value
        self.status_code = status_code
        super().__init__(self.message)

    def to_json(self):
        result = {"message": self.message}
        if self.field is not None:
            result["field"] = self.field
        if self.value is not None:
            result["value"] = self.value
        return result

    def __str__(self) -> str:
        return f"{self.message} {self.field} {self.value}"


class EntityCommon(dict):
    @classmethod
    def _get_route(cls, include_key: bool):
        route = f"/db/{cls._table_name}"
        if include_key:
            route += "".join([f"/{{{k}}}" for k in cls._key_columns])
        return route

    @classmethod
    def _register_route(cls, routes, verb: str, include_key: bool, handler: Callable):
        route = cls._get_route(include_key)

        @getattr(routes, verb)(route)
        async def _(request):
            try:
                data = await handler(request)
                if data is None:
                    return web.json_response(status=204)

                return web.json_response(data)
            except EntityError as e:
                return web.json_response(e.to_json(), status=e.status_code)

    @classmethod
    def _transform(cls, row: list[Any]):
        return {col: value for col, value in zip(cls._columns, row)}

    @classmethod
    def _transform_rows(cls, rows: list[list[Any]]):
        return [cls._transform(row) for row in rows]

    @classmethod
    def _extract_key(cls, request):
        return {key: request.match_info.get(key, None) for key in cls._key_columns}

    @classmethod
    def _validate(cls, fields: list[str], data: dict, allow_missing: bool = False):
        result = {}

        if not isinstance(data, dict):
            raise EntityError("Invalid data")

        # Ensure all required fields are present
        for field in data:
            if field not in fields:
                raise EntityError("Unknown field", field)

        for key in fields:
            col = cls._columns[key]
            if key not in data:
                if col.required and not allow_missing:
                    raise EntityError("Missing field", key)
                else:
                    # e.g. for updates, we allow missing fields
                    continue
            elif data[key] is None and col.required:
                # Dont allow None for required fields
                raise EntityError("Required field", key)

            # Validate data type
            value = data[key]

            if value is not None and not is_primitive(value):
                raise EntityError("Invalid value", key, value)

            try:
                type = col.type
                if value is not None and not isinstance(value, type):
                    value = type(value)
                result[key] = value
            except Exception:
                raise EntityError("Invalid value", key, value)

        return result

    @classmethod
    def _validate_id(cls, id: dict):
        return cls._validate(cls._key_columns, id)

    @classmethod
    def _validate_data(cls, data: dict, allow_missing: bool = False):
        return cls._validate(cls._columns.keys(), data, allow_missing)

    def __setattr__(self, name, value):
        if name in self._columns:
            self[name] = value
        super().__setattr__(name, value)

    def __getattr__(self, name):
        if name in self:
            return self[name]
        raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")


class GetEntity(EntityCommon):
    @classmethod
    def get(cls, top: Optional[int] = None, where: Optional[str] = None):
        limit = ""
        if top is not None and isinstance(top, int):
            limit = f" LIMIT {top}"
        result = db.execute(
            f"SELECT * FROM {cls._table_name}{limit}{f' WHERE {where}' if where else ''}",
        )

        # Map each row in result to an instance of the class
        return cls._transform_rows(result)

    @classmethod
    def register_route(cls, routes):
        async def get_handler(request):
            top = request.rel_url.query.get("top", None)
            if top is not None:
                try:
                    top = int(top)
                except Exception:
                    raise EntityError("Invalid top parameter", "top", top)
            return cls.get(top)

        cls._register_route(routes, "get", False, get_handler)


class GetEntityById(EntityCommon):
    @classmethod
    def get_by_id(cls, id: dict):
        id = cls._validate_id(id)

        result = db.execute(
            f"SELECT * FROM {cls._table_name} WHERE {cls._where_clause}",
            *[id[key] for key in cls._key_columns],
        )

        return cls._transform_rows(result)

    @classmethod
    def register_route(cls, routes):
        async def get_by_id_handler(request):
            id = cls._extract_key(request)
            return cls.get_by_id(id)

        cls._register_route(routes, "get", True, get_by_id_handler)


class CreateEntity(EntityCommon):
    @classmethod
    def create(cls, data: dict, allow_upsert: bool = False):
        data = cls._validate_data(data)
        values = ", ".join(["?"] * len(data))
        on_conflict = ""

        data_keys = ", ".join(list(data.keys()))
        if allow_upsert:
            # Remove key columns from data
            upsert_keys = [key for key in data if key not in cls._key_columns]

            set_clause = ", ".join([f"{k} = excluded.{k}" for k in upsert_keys])
            on_conflict = f" ON CONFLICT ({', '.join(cls._key_columns)}) DO UPDATE SET {set_clause}"
        sql = f"INSERT INTO {cls._table_name} ({data_keys}) VALUES ({values}){on_conflict} RETURNING *"
        result = db.execute(
            sql,
            *[data[key] for key in data],
        )

        if len(result) == 0:
            raise EntityError("Failed to create entity", status_code=500)

        return cls._transform_rows(result)[0]

    @classmethod
    def register_route(cls, routes):
        async def create_handler(request):
            data = await request.json()
            return cls.create(data)

        cls._register_route(routes, "post", False, create_handler)


class UpdateEntity(EntityCommon):
    @classmethod
    def update(cls, id: list, data: dict):
        id = cls._validate_id(id)
        data = cls._validate_data(data, allow_missing=True)

        sql = f"UPDATE {cls._table_name} SET {', '.join([f'{k} = ?' for k in data])} WHERE {cls._where_clause} RETURNING *"
        result = db.execute(
            sql,
            *[data[key] for key in data],
            *[id[key] for key in cls._key_columns],
        )

        if len(result) == 0:
            raise EntityError("Failed to update entity", status_code=404)

        return cls._transform_rows(result)[0]

    @classmethod
    def register_route(cls, routes):
        async def update_handler(request):
            id = cls._extract_key(request)
            data = await request.json()
            return cls.update(id, data)

        cls._register_route(routes, "patch", True, update_handler)


class UpsertEntity(CreateEntity):
    @classmethod
    def upsert(cls, data: dict):
        return cls.create(data, allow_upsert=True)

    @classmethod
    def register_route(cls, routes):
        async def upsert_handler(request):
            data = await request.json()
            return cls.upsert(data)

        cls._register_route(routes, "put", False, upsert_handler)


class DeleteEntity(EntityCommon):
    @classmethod
    def delete(cls, id: list):
        id = cls._validate_id(id)
        db.execute(
            f"DELETE FROM {cls._table_name} WHERE {cls._where_clause}",
            *[id[key] for key in cls._key_columns],
        )

    @classmethod
    def register_route(cls, routes):
        async def delete_handler(request):
            id = cls._extract_key(request)
            cls.delete(id)

        cls._register_route(routes, "delete", True, delete_handler)


class BaseEntity(GetEntity, CreateEntity, UpdateEntity, DeleteEntity, GetEntityById):
    pass


@dataclass
class Column:
    type: Any
    required: bool = False
    key: bool = False
    default: Any = None


def column(type_: Any, required: bool = False, key: bool = False, default: Any = None):
    return Column(type_, required, key, default)


def table(table_name: str):
    def decorator(cls):
        # Store table name
        cls._table_name = table_name

        # Process column definitions
        columns: dict[str, Column] = {}
        for attr_name, attr_value in cls.__dict__.items():
            if isinstance(attr_value, Column):
                columns[attr_name] = attr_value

        # Store columns metadata
        cls._columns = columns
        cls._key_columns = [col for col in columns if columns[col].key]
        cls._column_csv = ", ".join([col for col in columns])
        cls._where_clause = " AND ".join([f"{col} = ?" for col in cls._key_columns])

        # Add initialization
        original_init = cls.__init__

        @wraps(original_init)
        def new_init(self, *args, **kwargs):
            # Initialize columns with default values
            for col_name, col_def in cls._columns.items():
                setattr(self, col_name, col_def.default)
            # Call original init
            original_init(self, *args, **kwargs)

        cls.__init__ = new_init
        return cls

    return decorator


def test():
    @table("models")
    class Model(BaseEntity):
        id: int = column(int, required=True, key=True)
        path: str = column(str, required=True)
        name: str = column(str, required=True)
        description: Optional[str] = column(str)
        architecture: Optional[str] = column(str)
        type: str = column(str, required=True)
        hash: Optional[str] = column(str)
        source_url: Optional[str] = column(str)

    return Model


@table("test")
class Test(GetEntity, CreateEntity):
    id: int = column(int, required=True, key=True)
    test: str = column(str, required=True)


Model = test()
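The core mechanism of the removed module is the @table class decorator harvesting Column descriptors from the class body into class-level metadata that the CRUD mixins consume. A condensed, runnable sketch of just that mechanism, stripped of the db/aiohttp wiring (names below mirror the module but this is an illustration, not the module itself):

from dataclasses import dataclass
from typing import Any

@dataclass
class Column:
    type: Any
    required: bool = False
    key: bool = False

def table(name):
    def decorator(cls):
        # Harvest Column descriptors from the class body into metadata.
        cls._table_name = name
        cls._columns = {k: v for k, v in cls.__dict__.items() if isinstance(v, Column)}
        cls._key_columns = [k for k, c in cls._columns.items() if c.key]
        cls._where_clause = " AND ".join(f"{k} = ?" for k in cls._key_columns)
        return cls
    return decorator

@table("models")
class Model:
    id = Column(int, required=True, key=True)
    name = Column(str, required=True)

print(Model._table_name, Model._key_columns, Model._where_clause)
# models ['id'] id = ?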
@@ -1,32 +0,0 @@
from app.database.db import db
from aiohttp import web

def create_routes(
    routes, prefix, entity, get=False, get_by_id=False, post=False, delete=False
):
    if get:
        @routes.get(f"/{prefix}/{table}")
        async def get_table(request):
            connection = db.get_connection()
            cursor = connection.cursor()
            cursor.execute(f"SELECT * FROM {table}")
            rows = cursor.fetchall()
            return web.json_response(rows)

    if get_by_id:
        @routes.get(f"/{prefix}/{table}/{id}")
        async def get_table_by_id(request):
            connection = db.get_connection()
            cursor = connection.cursor()
            cursor.execute(f"SELECT * FROM {table} WHERE id = {id}")
            row = cursor.fetchone()
            return web.json_response(row)

    if post:
        @routes.post(f"/{prefix}/{table}")
        async def post_table(request):
            data = await request.json()
            connection = db.get_connection()
            cursor = connection.cursor()
            cursor.execute(f"INSERT INTO {table} ({data}) VALUES ({data})")
            return web.json_response({"status": "success"})
@@ -1,79 +0,0 @@
import logging
import os
import sqlite3
from app.database.versions.v1 import v1


class DatabaseUpdater:
    def __init__(self, connection, database_path):
        self.connection = connection
        self.database_path = database_path
        self.current_version = self.get_db_version()
        self.version_updates = {
            1: v1,
        }
        self.max_version = max(self.version_updates.keys())
        self.update_required = self.current_version < self.max_version
        logging.info(f"Database version: {self.current_version}")

    def get_db_version(self):
        return self.connection.execute("PRAGMA user_version").fetchone()[0]

    def backup(self):
        bkp_path = self.database_path + ".bkp"
        if os.path.exists(bkp_path):
            # TODO: auto-rollback failed upgrades
            raise Exception(
                f"Database backup already exists, this indicates that a previous upgrade failed. Please restore this backup before continuing. Backup location: {bkp_path}"
            )

        bkp = sqlite3.connect(bkp_path)
        self.connection.backup(bkp)
        bkp.close()
        logging.info("Database backup taken pre-upgrade.")
        return bkp_path

    def update(self):
        if not self.update_required:
            return None

        bkp_version = self.current_version
        bkp_path = None
        if self.current_version > 0:
            bkp_path = self.backup()

        logging.info(f"Updating database: {self.current_version} -> {self.max_version}")

        dirname = os.path.dirname(__file__)
        cursor = self.connection.cursor()
        for version in range(self.current_version + 1, self.max_version + 1):
            filename = os.path.join(dirname, f"versions/v{version}.sql")
            if not os.path.exists(filename):
                raise Exception(
                    f"Database update script for version {version} not found"
                )

            try:
                with open(filename, "r") as file:
                    sql = file.read()
                    cursor.executescript(sql)
            except Exception as e:
                raise Exception(
                    f"Failed to execute update script for version {version}: {e}"
                )

            method = self.version_updates[version]
            if method is not None:
                method(cursor)

        cursor.execute("PRAGMA user_version = %d" % self.max_version)
        self.connection.commit()
        cursor.close()
        self.current_version = self.get_db_version()

        if bkp_path:
            # Keep a copy of the backup in case something goes wrong and we need to rollback
            os.rename(bkp_path, self.database_path + f".v{bkp_version}.bkp")
        logging.info(f"Upgrade to successful.")

        return (bkp_version, self.current_version)
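The removed updater keys schema migrations off SQLite's user_version pragma: read the stored version, run each missing migration in order, then bump the pragma. A minimal sketch of the same loop, with an inline dict standing in for the versions/*.sql files (the table below is invented for illustration):

import sqlite3

# Inline stand-in for the versions/v{n}.sql scripts; the table is invented.
MIGRATIONS = {
    1: "CREATE TABLE IF NOT EXISTS models (id INTEGER PRIMARY KEY, path TEXT NOT NULL);",
}

def migrate(conn: sqlite3.Connection):
    current = conn.execute("PRAGMA user_version").fetchone()[0]
    for version in range(current + 1, max(MIGRATIONS) + 1):
        conn.executescript(MIGRATIONS[version])
        # PRAGMA does not accept bound parameters, hence the interpolation.
        conn.execute("PRAGMA user_version = %d" % version)
    conn.commit()

conn = sqlite3.connect(":memory:")
migrate(conn)
print(conn.execute("PRAGMA user_version").fetchone()[0])  # 1
migrate(conn)  # second run is a no-op: the stored version is already current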
@@ -1,17 +0,0 @@
from folder_paths import folder_names_and_paths, get_filename_list, get_full_path


def v1(cursor):
    print("Updating to v1")
    for folder_name in folder_names_and_paths.keys():
        if folder_name == "custom_nodes":
            continue

        files = get_filename_list(folder_name)
        for file in files:
            file_path = get_full_path(folder_name, file)
            file_without_extension = file.rsplit(".", maxsplit=1)[0]
            cursor.execute(
                "INSERT INTO models (path, name, type) VALUES (?, ?, ?)",
                (file_path, file_without_extension, folder_name),
            )
@@ -1,41 +0,0 @@
CREATE TABLE IF NOT EXISTS
  models (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    path TEXT NOT NULL,
    name TEXT NOT NULL,
    description TEXT,
    architecture TEXT,
    type TEXT NOT NULL,
    hash TEXT,
    source_url TEXT
  );

CREATE TABLE IF NOT EXISTS
  tags (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL UNIQUE
  );

CREATE TABLE IF NOT EXISTS
  model_tags (
    model_id INTEGER NOT NULL,
    tag_id INTEGER NOT NULL,
    PRIMARY KEY (model_id, tag_id),
    FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE,
    FOREIGN KEY (tag_id) REFERENCES tags (id) ON DELETE CASCADE
  );

INSERT INTO
  tags (name)
VALUES
  ('character'),
  ('style'),
  ('concept'),
  ('clothing'),
  ('poses'),
  ('background'),
  ('vehicle'),
  ('buildings'),
  ('objects'),
  ('animal'),
  ('action');
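To see how the removed schema's junction table behaves, a small standalone demo (schema abbreviated, sample rows invented) of tagging through model_tags and the ON DELETE CASCADE cleanup:

import sqlite3

con = sqlite3.connect(":memory:")
con.execute("PRAGMA foreign_keys = ON")  # required for CASCADE to fire
con.executescript("""
CREATE TABLE models (id INTEGER PRIMARY KEY, name TEXT);
CREATE TABLE tags (id INTEGER PRIMARY KEY, name TEXT UNIQUE);
CREATE TABLE model_tags (
    model_id INTEGER NOT NULL REFERENCES models (id) ON DELETE CASCADE,
    tag_id INTEGER NOT NULL REFERENCES tags (id) ON DELETE CASCADE,
    PRIMARY KEY (model_id, tag_id));
INSERT INTO models (name) VALUES ('some_lora');
INSERT INTO tags (name) VALUES ('style');
INSERT INTO model_tags VALUES (1, 1);
""")
print(con.execute("""
    SELECT m.name, t.name FROM models m
    JOIN model_tags mt ON mt.model_id = m.id
    JOIN tags t ON t.id = mt.tag_id""").fetchall())  # [('some_lora', 'style')]
con.execute("DELETE FROM models WHERE id = 1")  # cascades into model_tags
print(con.execute("SELECT COUNT(*) FROM model_tags").fetchone())  # (0,)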
@@ -1,63 +0,0 @@
import hashlib
import logging
import threading
import time
from comfy.cli_args import args


class ModelHasher:

    def __init__(self):
        self._thread = None
        self._lock = threading.Lock()
        self._model_entity = None

    def start(self, model_entity):
        if args.disable_model_hashing:
            return

        self._model_entity = model_entity

        if self._thread is None:
            # Lock to prevent multiple threads from starting
            with self._lock:
                if self._thread is None:
                    self._thread = threading.Thread(target=self._hash_models)
                    self._thread.daemon = True
                    self._thread.start()

    def _get_models(self):
        models = self._model_entity.get("WHERE hash IS NULL")
        return models

    def _hash_model(self, model_path):
        h = hashlib.sha256()
        b = bytearray(128 * 1024)
        mv = memoryview(b)
        with open(model_path, "rb", buffering=0) as f:
            while n := f.readinto(mv):
                h.update(mv[:n])
        hash = h.hexdigest()
        return hash

    def _hash_models(self):
        while True:
            models = self._get_models()

            if len(models) == 0:
                break

            for model in models:
                time.sleep(0)
                now = time.time()
                logging.info(f"Hashing model {model['path']}")
                hash = self._hash_model(model["path"])
                logging.info(
                    f"Hashed model {model['path']} in {time.time() - now} seconds"
                )
                self._model_entity.update((model["id"],), {"hash": hash})

        self._thread = None


model_hasher = ModelHasher()
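The removed _hash_model shows the standard trick for hashing large checkpoints without loading them into memory at once: one reusable buffer plus readinto. The same idea as a standalone helper, exercised against an in-memory stream (names here are illustrative):

import hashlib
import io

def sha256_stream(f, chunk_size=128 * 1024):
    # One reusable buffer + readinto: hashes arbitrarily large files with
    # constant memory, the same shape as the removed _hash_model.
    h = hashlib.sha256()
    buf = bytearray(chunk_size)
    view = memoryview(buf)
    while n := f.readinto(view):
        h.update(view[:n])
    return h.hexdigest()

data = b"x" * (300 * 1024)  # stand-in for a model file on disk
assert sha256_stream(io.BytesIO(data)) == hashlib.sha256(data).hexdigest()
print("ok")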
@@ -143,13 +143,9 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
 parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
 parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
 
-parser.add_argument("--memory-database", default=False, action="store_true", help="Use an in-memory database instead of a file-based one.")
-parser.add_argument("--disable-model-hashing", action="store_true", help="Disable model hashing.")
-
 # The default built-in provider hosted under web/
 DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
 
 
 parser.add_argument(
     "--front-end-version",
     type=str,
@@ -168,14 +168,18 @@ class Attention(nn.Module):
         k = self.to_k[1](k)
         v = self.to_v[1](v)
         if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
-            q = apply_rotary_pos_emb(q, rope_emb)
-            k = apply_rotary_pos_emb(k, rope_emb)
-        return q, k, v
+            # apply_rotary_pos_emb inlined
+            q_shape = q.shape
+            q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+            q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
+            q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
 
-    def cal_attn(self, q, k, v, mask=None):
-        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
-        out = rearrange(out, " b n s c -> s b (n c)")
-        return self.to_out(out)
+            # apply_rotary_pos_emb inlined
+            k_shape = k.shape
+            k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+            k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
+            k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
+        return q, k, v
 
     def forward(
         self,
@@ -191,7 +195,10 @@ class Attention(nn.Module):
             context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
         """
         q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
-        return self.cal_attn(q, k, v, mask)
+        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
+        del q, k, v
+        out = rearrange(out, " b n s c -> s b (n c)")
+        return self.to_out(out)
 
 
 class FeedForward(nn.Module):
@@ -788,10 +795,7 @@ class GeneralDITTransformerBlock(nn.Module):
         crossattn_mask: Optional[torch.Tensor] = None,
         rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
         adaln_lora_B_3D: Optional[torch.Tensor] = None,
-        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if extra_per_block_pos_emb is not None:
-            x = x + extra_per_block_pos_emb
         for block in self.blocks:
             x = block(
                 x,
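The inlining above is plain rotary position embedding applied pairwise. A toy version with hand-picked shapes (a single position, channel dim 4; the real tensors are batched multi-head layouts, and rotate_half_pairs is an illustrative name) makes the reshape/movedim dance concrete:

import math
import torch

def rotate_half_pairs(q: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
    # Same reshaping dance as the inlined code above: split the channel dim
    # into (2, d/2), so q[..., 0] / q[..., 1] are the two halves of each pair.
    q_shape = q.shape
    q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
    q = rope[..., 0] * q[..., 0] + rope[..., 1] * q[..., 1]
    return q.movedim(-1, -2).reshape(*q_shape)

# One position, channel dim 4 -> two rotation pairs (x[0], x[2]) and (x[1], x[3]).
theta = math.pi / 2
c, s = math.cos(theta), math.sin(theta)
# rope[..., 0] is the (cos, sin) column, rope[..., 1] the (-sin, cos) column:
# one 2x2 rotation matrix per pair, shape (d/2, 2, 2).
rope = torch.tensor([[[c, -s], [s, c]],
                     [[c, -s], [s, c]]])
q = torch.tensor([1.0, 1.0, 0.0, 0.0])
print(rotate_half_pairs(q, rope))  # tensor([0., 0., 1., 1.]): each pair rotated 90 degrees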
@@ -30,6 +30,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import logging
 
+from comfy.ldm.modules.diffusionmodules.model import vae_attention
+
 from .patching import (
     Patcher,
     Patcher3D,
@@ -400,6 +402,8 @@ class CausalAttnBlock(nn.Module):
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )
 
+        self.optimized_attention = vae_attention()
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h_ = x
         h_ = self.norm(h_)
@@ -413,18 +417,7 @@ class CausalAttnBlock(nn.Module):
         v, batch_size = time2batch(v)
 
         b, c, h, w = q.shape
-        q = q.reshape(b, c, h * w)
-        q = q.permute(0, 2, 1)
-        k = k.reshape(b, c, h * w)
-        w_ = torch.bmm(q, k)
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = F.softmax(w_, dim=2)
-
-        # attend to values
-        v = v.reshape(b, c, h * w)
-        w_ = w_.permute(0, 2, 1)
-        h_ = torch.bmm(v, w_)
-        h_ = h_.reshape(b, c, h, w)
+        h_ = self.optimized_attention(q, k, v)
 
         h_ = batch2time(h_, batch_size)
         h_ = self.proj_out(h_)
@@ -871,18 +864,16 @@ class EncoderFactorized(nn.Module):
         x = self.patcher3d(x)
 
         # downsampling
-        hs = [self.conv_in(x)]
+        h = self.conv_in(x)
         for i_level in range(self.num_resolutions):
             for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](hs[-1])
+                h = self.down[i_level].block[i_block](h)
                 if len(self.down[i_level].attn) > 0:
                     h = self.down[i_level].attn[i_block](h)
-                hs.append(h)
             if i_level != self.num_resolutions - 1:
-                hs.append(self.down[i_level].downsample(hs[-1]))
+                h = self.down[i_level].downsample(h)
 
         # middle
-        h = hs[-1]
         h = self.mid.block_1(h)
         h = self.mid.attn_1(h)
         h = self.mid.block_2(h)
@@ -281,54 +281,76 @@ class UnPatcher3D(UnPatcher):
         hh = hh.to(dtype=dtype)
 
         xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
+        del x
 
         # Handles height transposed convolutions.
         xll = F.conv_transpose3d(
             xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xlll
 
         xll += F.conv_transpose3d(
             xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xllh
 
         xlh = F.conv_transpose3d(
             xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xlhl
 
         xlh += F.conv_transpose3d(
             xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xlhh
 
         xhl = F.conv_transpose3d(
             xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
        )
+        del xhll
 
         xhl += F.conv_transpose3d(
             xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhlh
 
         xhh = F.conv_transpose3d(
             xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhhl
 
         xhh += F.conv_transpose3d(
             xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
+        del xhhh
 
         # Handles width transposed convolutions.
         xl = F.conv_transpose3d(
             xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xll
 
         xl += F.conv_transpose3d(
             xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xlh
 
         xh = F.conv_transpose3d(
             xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xhl
 
         xh += F.conv_transpose3d(
             xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
+        del xhh
 
         # Handles time axis transposed convolutions.
         x = F.conv_transpose3d(
             xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
         )
+        del xl
 
         x += F.conv_transpose3d(
             xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
         )
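The sprinkled del statements above are purely about peak memory: dropping the last reference to each intermediate lets PyTorch free (or reuse) its buffer before the next conv_transpose3d allocates. A toy illustration of the same accumulate-and-release pattern (names invented):

import torch

def sum_parts_eager_free(parts):
    # Accumulate a sum while dropping each addend's last reference as soon as
    # it is consumed, so its buffer can be freed before the next allocation.
    total = None
    while parts:
        p = parts.pop()          # remove the list's reference
        total = p if total is None else total + p
        del p                    # remove the loop's reference
    return total

print(sum_parts_eager_free([torch.ones(2, 2) for _ in range(3)]))  # all 3.0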
@@ -168,7 +168,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )
 
-        self.build_pos_embed(device=device)
+        self.build_pos_embed(device=device, dtype=dtype)
         self.block_x_format = block_x_format
         self.use_adaln_lora = use_adaln_lora
         self.adaln_lora_dim = adaln_lora_dim
@@ -210,7 +210,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )
 
-    def build_pos_embed(self, device=None):
+    def build_pos_embed(self, device=None, dtype=None):
         if self.pos_emb_cls == "rope3d":
             cls_type = VideoRopePosition3DEmb
         else:
@@ -242,6 +242,7 @@ class GeneralDIT(nn.Module):
             kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
             kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
             kwargs["device"] = device
+            kwargs["dtype"] = dtype
             self.extra_pos_embedder = LearnablePosEmbAxis(
                 **kwargs,
             )
@@ -476,6 +477,8 @@ class GeneralDIT(nn.Module):
             inputs["original_shape"],
         )
+        extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
         del inputs
 
         if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
             assert (
                 x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
@@ -486,6 +489,8 @@ class GeneralDIT(nn.Module):
                 self.blocks["block0"].x_format == block.x_format
             ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
 
+            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
+                x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
             x = block(
                 x,
                 affline_emb_B_D,
@@ -493,7 +498,6 @@ class GeneralDIT(nn.Module):
                 crossattn_mask,
                 rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                 adaln_lora_B_3D=adaln_lora_B_3D,
-                extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
             )
 
         x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
@@ -173,6 +173,7 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         len_w: int,
         len_t: int,
         device=None,
+        dtype=None,
         **kwargs,
     ):
         """
@@ -184,9 +185,9 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         self.interpolation = interpolation
         assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
 
-        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
-        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
-        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))
+        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
+        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
+        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
 
 
     def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
@@ -89,8 +89,8 @@ class CausalContinuousVideoTokenizer(nn.Module):
         self.distribution = IdentityDistribution()  # ContinuousFormulation[formulation_name].value()
 
         num_parameters = sum(param.numel() for param in self.parameters())
-        logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
-        logging.info(
+        logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
+        logging.debug(
             f"z_channels={z_channels}, latent_channels={self.latent_channels}."
         )
 
@@ -230,8 +230,7 @@ class SingleStreamBlock(nn.Module):
 
     def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
         mod, _ = self.modulation(vec)
-        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
-        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
@@ -5,8 +5,15 @@ from torch import Tensor
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management
 
 
 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
-    q, k = apply_rope(q, k, pe)
+    q_shape = q.shape
+    k_shape = k.shape
+
+    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
+    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
+    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
+    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
 
     heads = q.shape[1]
     x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
     return out
 
 
+def vae_attention():
+    if model_management.xformers_enabled_vae():
+        logging.info("Using xformers attention in VAE")
+        return xformers_attention
+    elif model_management.pytorch_attention_enabled():
+        logging.info("Using pytorch attention in VAE")
+        return pytorch_attention
+    else:
+        logging.info("Using split attention in VAE")
+        return normal_attention
+
 class AttnBlock(nn.Module):
     def __init__(self, in_channels, conv_op=ops.Conv2d):
         super().__init__()
@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
                                         stride=1,
                                         padding=0)
 
-        if model_management.xformers_enabled_vae():
-            logging.info("Using xformers attention in VAE")
-            self.optimized_attention = xformers_attention
-        elif model_management.pytorch_attention_enabled():
-            logging.info("Using pytorch attention in VAE")
-            self.optimized_attention = pytorch_attention
-        else:
-            logging.info("Using split attention in VAE")
-            self.optimized_attention = normal_attention
+        self.optimized_attention = vae_attention()
 
     def forward(self, x):
         h_ = x
@@ -388,8 +388,8 @@ class VAE:
             ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
             self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
             #TODO: these values are a bit off because this is not a standard VAE
-            self.memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
-            self.memory_used_encode = lambda shape, dtype: (500 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
             self.working_dtypes = [torch.bfloat16, torch.float32]
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
@@ -788,7 +788,7 @@ class HunyuanVideo(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.HunyuanVideo
 
-    memory_usage_factor = 2.0 #TODO
+    memory_usage_factor = 1.8 #TODO
 
     supported_inference_dtypes = [torch.bfloat16, torch.float32]
 
@@ -839,7 +839,7 @@ class CosmosT2V(supported_models_base.BASE):
     unet_extra_config = {}
    latent_format = latent_formats.Cosmos1CV8x8x8
 
-    memory_usage_factor = 2.4 #TODO
+    memory_usage_factor = 1.6 #TODO
 
     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO
 
@@ -71,8 +71,8 @@ class CosmosImageToVideoLatent:
             mask[:, :, -latent_temp.shape[-3]:] *= 0.0
 
         out_latent = {}
-        out_latent["samples"] = latent
-        out_latent["noise_mask"] = mask
+        out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
+        out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
         return (out_latent,)
 
 
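The node fix above broadcasts a single-sample latent and mask across the requested batch by repeating only the batch dim; a minimal shape check (sizes invented for illustration):

import torch

latent = torch.zeros(1, 16, 5, 8, 8)  # [B=1, C, T, H, W]
batch_size = 4
# Repeat only the batch dim; (1,) * (ndim - 1) leaves every other dim alone.
repeated = latent.repeat((batch_size,) + (1,) * (latent.ndim - 1))
print(repeated.shape)  # torch.Size([4, 16, 5, 8, 8])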
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.10"
+__version__ = "0.3.11"
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.10"
+version = "0.3.11"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@@ -2,6 +2,7 @@ torch
 torchsde
 torchvision
 torchaudio
+numpy>=1.25.0
 einops
 transformers>=4.28.1
 tokenizers>=0.13.3
server.py (17 lines changed)
@@ -34,9 +34,6 @@ from app.model_manager import ModelFileManager
 from app.custom_node_manager import CustomNodeManager
 from typing import Optional
 from api_server.routes.internal.internal_routes import InternalRoutes
-from app.database.entities import get_entity, init_entities
-from app.database.db import db
-from app.model_hasher import model_hasher
 
 class BinaryEventTypes:
     PREVIEW_IMAGE = 1
@@ -685,25 +682,11 @@ class PromptServer():
         timeout = aiohttp.ClientTimeout(total=None)  # no timeout
         self.client_session = aiohttp.ClientSession(timeout=timeout)
 
-    def init_db(self, routes):
-        init_entities(routes)
-        models = get_entity("models")
-
-        if db.exists:
-            model_hasher.start(models)
-        else:
-            def on_db_update(_, __):
-                model_hasher.start(models)
-
-            db.register_after_update_callback(on_db_update)
-
-
     def add_routes(self):
         self.user_manager.add_routes(self.routes)
         self.model_file_manager.add_routes(self.routes)
         self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
         self.app.add_subapp('/internal', self.internal_routes.get_app())
-        self.init_db(self.routes)
 
         # Prefix every route with /api for easier matching for delegation.
         # This is very useful for frontend dev server, which need to forward
@@ -1,513 +0,0 @@
from comfy.cli_args import args

args.memory_database = True  # force in-memory database for testing

from typing import Callable, Optional
import pytest
import pytest_asyncio
from unittest.mock import patch
from aiohttp import web
from app.database.entities import (
    DeleteEntity,
    column,
    table,
    Column,
    GetEntity,
    GetEntityById,
    CreateEntity,
    UpsertEntity,
    UpdateEntity,
)
from app.database.db import db

pytestmark = pytest.mark.asyncio


def create_table(entity):
    # reset db
    db.close()

    cols: list[Column] = entity._columns
    # Create tables as temporary so when we close the db, the tables are dropped for next test
    sql = f"CREATE TEMPORARY TABLE {entity._table_name} ( "
    for col_name, col in cols.items():
        type = None
        if col.type == int:
            type = "INTEGER"
        elif col.type == str:
            type = "TEXT"

        sql += f"{col_name} {type}"
        if col.required:
            sql += " NOT NULL"
        sql += ", "

    sql += f"PRIMARY KEY ({', '.join(entity._key_columns)})"
    sql += ")"
    db.execute(sql)


async def wrap_db(method: Callable, expected_sql: str, expected_args: list):
    with patch.object(db, "execute", wraps=db.execute) as mock:
        response = await method()
        assert mock.call_count == 1
        assert mock.call_args[0][0] == expected_sql
        assert mock.call_args[0][1:] == expected_args
        return response


@pytest.fixture
def getable_entity():
    @table("getable_entity")
    class GetableEntity(GetEntity):
        id: int = column(int, required=True, key=True)
        test: str = column(str, required=True)
        nullable: Optional[str] = column(str)

    return GetableEntity


@pytest.fixture
def getable_by_id_entity():
    @table("getable_by_id_entity")
    class GetableByIdEntity(GetEntityById):
        id: int = column(int, required=True, key=True)
        test: str = column(str, required=True)

    return GetableByIdEntity


@pytest.fixture
def getable_by_id_composite_entity():
    @table("getable_by_id_composite_entity")
    class GetableByIdCompositeEntity(GetEntityById):
        id1: str = column(str, required=True, key=True)
        id2: int = column(int, required=True, key=True)
        test: str = column(str, required=True)

    return GetableByIdCompositeEntity


@pytest.fixture
def creatable_entity():
    @table("creatable_entity")
    class CreatableEntity(CreateEntity):
        id: int = column(int, required=True, key=True)
        test: str = column(str, required=True)
        reqd: str = column(str, required=True)
        nullable: Optional[str] = column(str)

    return CreatableEntity


@pytest.fixture
def upsertable_entity():
    @table("upsertable_entity")
    class UpsertableEntity(UpsertEntity):
        id: int = column(int, required=True, key=True)
        test: str = column(str, required=True)
        reqd: str = column(str, required=True)
        nullable: Optional[str] = column(str)

    return UpsertableEntity


@pytest.fixture
def updateable_entity():
    @table("updateable_entity")
    class UpdateableEntity(UpdateEntity):
        id: int = column(int, required=True, key=True)
        reqd: str = column(str, required=True)

    return UpdateableEntity


@pytest.fixture
def deletable_entity():
    @table("deletable_entity")
    class DeletableEntity(DeleteEntity):
        id: int = column(int, required=True, key=True)

    return DeletableEntity


@pytest.fixture
def deletable_composite_entity():
    @table("deletable_composite_entity")
    class DeletableCompositeEntity(DeleteEntity):
        id1: str = column(str, required=True, key=True)
        id2: int = column(int, required=True, key=True)

    return DeletableCompositeEntity


@pytest.fixture()
def entity(request):
    value = request.getfixturevalue(request.param)
    create_table(value)
    return value


@pytest_asyncio.fixture
async def client(aiohttp_client, app):
    return await aiohttp_client(app)


@pytest.fixture
def app(entity):
    app = web.Application()
    routes = web.RouteTableDef()
    entity.register_route(routes)
    app.add_routes(routes)
    return app


@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_empty_response(client):
    expected_sql = "SELECT * FROM getable_entity"
    expected_args = ()
    response = await wrap_db(
        lambda: client.get("/db/getable_entity"), expected_sql, expected_args
    )

    assert response.status == 200
    assert await response.json() == []


@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_data(client):
    # seed db
    db.execute(
        "INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2')"
    )

    expected_sql = "SELECT * FROM getable_entity"
    expected_args = ()
    response = await wrap_db(
        lambda: client.get("/db/getable_entity"), expected_sql, expected_args
    )

    assert response.status == 200
    assert await response.json() == [
        {"id": 1, "test": "test1", "nullable": None},
        {"id": 2, "test": "test2", "nullable": "test2"},
    ]


@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_top_parameter(client):
    # seed with 3 rows
    db.execute(
        "INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2'), (3, 'test3', 'test3')"
    )

    expected_sql = "SELECT * FROM getable_entity LIMIT 2"
    expected_args = ()
    response = await wrap_db(
        lambda: client.get("/db/getable_entity?top=2"),
        expected_sql,
        expected_args,
    )

    assert response.status == 200
    assert await response.json() == [
        {"id": 1, "test": "test1", "nullable": None},
        {"id": 2, "test": "test2", "nullable": "test2"},
    ]


@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_invalid_top_parameter(client):
    response = await client.get("/db/getable_entity?top=hello")
    assert response.status == 400
    assert await response.json() == {
        "message": "Invalid top parameter",
        "field": "top",
        "value": "hello",
    }


@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
async def test_get_model_by_id_empty_response(client):
    # seed db
    db.execute("INSERT INTO getable_by_id_entity (id, test) VALUES (1, 'test1')")

    expected_sql = "SELECT * FROM getable_by_id_entity WHERE id = ?"
    expected_args = (1,)
    response = await wrap_db(
        lambda: client.get("/db/getable_by_id_entity/1"),
        expected_sql,
        expected_args,
    )

    assert response.status == 200
    assert await response.json() == [
        {"id": 1, "test": "test1"},
    ]


@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
async def test_get_model_by_id_with_invalid_id(client):
    response = await client.get("/db/getable_by_id_entity/hello")
    assert response.status == 400
    assert await response.json() == {
        "message": "Invalid value",
        "field": "id",
        "value": "hello",
    }


@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
async def test_get_model_by_id_composite(client):
    # seed db
    db.execute(
        "INSERT INTO getable_by_id_composite_entity (id1, id2, test) VALUES ('one', 2, 'test')"
    )

    expected_sql = (
        "SELECT * FROM getable_by_id_composite_entity WHERE id1 = ? AND id2 = ?"
    )
    expected_args = ("one", 2)
    response = await wrap_db(
        lambda: client.get("/db/getable_by_id_composite_entity/one/2"),
        expected_sql,
        expected_args,
    )

    assert response.status == 200
    assert await response.json() == [
        {"id1": "one", "id2": 2, "test": "test"},
    ]


@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
async def test_get_model_by_id_composite_with_invalid_id(client):
    response = await client.get("/db/getable_by_id_composite_entity/hello/hello")
    assert response.status == 400
    assert await response.json() == {
        "message": "Invalid value",
        "field": "id2",
        "value": "hello",
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model(client):
    expected_sql = (
        "INSERT INTO creatable_entity (id, test, reqd) VALUES (?, ?, ?) RETURNING *"
    )
    expected_args = (1, "test1", "reqd1")
    response = await wrap_db(
        lambda: client.post(
            "/db/creatable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
        ),
        expected_sql,
        expected_args,
    )

    assert response.status == 200
    assert await response.json() == {
        "id": 1,
        "test": "test1",
        "reqd": "reqd1",
        "nullable": None,
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_missing_required_field(client):
    response = await client.post(
        "/db/creatable_entity", json={"id": 1, "test": "test1"}
    )

    assert response.status == 400
    assert await response.json() == {
        "message": "Missing field",
        "field": "reqd",
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_missing_key_field(client):
    response = await client.post(
        "/db/creatable_entity",
        json={"test": "test1", "reqd": "reqd1"},  # Missing 'id' which is a key
    )

    assert response.status == 400
    assert await response.json() == {
        "message": "Missing field",
        "field": "id",
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_key_data(client):
    response = await client.post(
        "/db/creatable_entity",
        json={
            "id": "not_an_integer",
            "test": "test1",
            "reqd": "reqd1",
        },  # id should be int
    )

    assert response.status == 400
    assert await response.json() == {
        "message": "Invalid value",
        "field": "id",
        "value": "not_an_integer",
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_data(client):
    response = await client.post(
        "/db/creatable_entity",
        json={"id": "aaa", "test": "123", "reqd": "reqd1"},  # id should be int
    )

    assert response.status == 400
    assert await response.json() == {
        "message": "Invalid value",
        "field": "id",
        "value": "aaa",
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_type(client):
    response = await client.post(
        "/db/creatable_entity",
        json={
            "id": 1,
            "test": ["invalid_array"],
            "reqd": "reqd1",
        },  # test should be string
    )

    assert response.status == 400
    assert await response.json() == {
        "message": "Invalid value",
        "field": "test",
        "value": ["invalid_array"],
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_name(client):
    response = await client.post(
        "/db/creatable_entity",
        json={"id": 1, "test": "test1", "reqd": "reqd1", "nonexistent_field": "value"},
    )

    assert response.status == 400
    assert await response.json() == {
        "message": "Unknown field",
        "field": "nonexistent_field",
    }


@pytest.mark.parametrize("entity", ["upsertable_entity"], indirect=True)
async def test_upsert_model(client):
    expected_sql = (
        "INSERT INTO upsertable_entity (id, test, reqd) VALUES (?, ?, ?) "
        "ON CONFLICT (id) DO UPDATE SET test = excluded.test, reqd = excluded.reqd "
        "RETURNING *"
    )
    expected_args = (1, "test1", "reqd1")
    response = await wrap_db(
        lambda: client.put(
            "/db/upsertable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
        ),
        expected_sql,
        expected_args,
    )

    assert response.status == 200
    assert await response.json() == {
        "id": 1,
        "test": "test1",
        "reqd": "reqd1",
        "nullable": None,
    }


@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model(client):
    # seed db
    db.execute("INSERT INTO updateable_entity (id, reqd) VALUES (1, 'test1')")

    expected_sql = "UPDATE updateable_entity SET reqd = ? WHERE id = ? RETURNING *"
    expected_args = ("updated_test", 1)
    response = await wrap_db(
        lambda: client.patch("/db/updateable_entity/1", json={"reqd": "updated_test"}),
        expected_sql,
        expected_args,
    )

    assert response.status == 200
    assert await response.json() == {
        "id": 1,
        "reqd": "updated_test",
    }


@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model_reject_null_required_field(client):
    response = await client.patch("/db/updateable_entity/1", json={"reqd": None})

    assert response.status == 400
    assert await response.json() == {
        "message": "Required field",
        "field": "reqd",
    }


@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model_reject_invalid_field(client):
    response = await client.patch("/db/updateable_entity/1", json={"hello": "world"})

    assert response.status == 400
    assert await response.json() == {
        "message": "Unknown field",
        "field": "hello",
    }


@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model_reject_missing_record(client):
    response = await client.patch(
        "/db/updateable_entity/1", json={"reqd": "updated_test"}
    )

    assert response.status == 404
    assert await response.json() == {
        "message": "Failed to update entity",
    }


@pytest.mark.parametrize("entity", ["deletable_entity"], indirect=True)
async def test_delete_model(client):
    expected_sql = "DELETE FROM deletable_entity WHERE id = ?"
    expected_args = (1,)
    response = await wrap_db(
        lambda: client.delete("/db/deletable_entity/1"),
        expected_sql,
        expected_args,
    )

    assert response.status == 204


@pytest.mark.parametrize("entity", ["deletable_composite_entity"], indirect=True)
async def test_delete_model_composite_key(client):
    expected_sql = "DELETE FROM deletable_composite_entity WHERE id1 = ? AND id2 = ?"
    expected_args = ("one", 2)
    response = await wrap_db(
        lambda: client.delete("/db/deletable_composite_entity/one/2"),
        expected_sql,
        expected_args,
    )

    assert response.status == 204
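wrap_db above relies on unittest.mock's wraps= argument: the patched db.execute still runs for real while the mock records each call, so the tests can assert on the exact SQL and bound arguments. The spy technique in isolation (Repo is a hypothetical stand-in; only the call shape matters):

from unittest.mock import patch

class Repo:
    # Hypothetical stand-in for the db object.
    def execute(self, sql, *args):
        return []

repo = Repo()
# wraps= keeps the real method running while the mock records every call,
# so you can assert on the exact SQL and bound arguments afterwards.
with patch.object(repo, "execute", wraps=repo.execute) as mock:
    repo.execute("SELECT * FROM models WHERE id = ?", 1)
mock.assert_called_once_with("SELECT * FROM models WHERE id = ?", 1)
print("ok")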