mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Compare commits
3 Commits
v0.3.17
...
model_mana
Author | SHA1 | Date | |
---|---|---|---|
|
01110de8a3 | ||
|
785a220757 | ||
|
b6b475191d |
0
app/database/__init__.py
Normal file
0
app/database/__init__.py
Normal file
126
app/database/db.py
Normal file
126
app/database/db.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from queue import Queue, Empty, Full
|
||||
import threading
|
||||
from app.database.updater import DatabaseUpdater
|
||||
import folder_paths
|
||||
from comfy.cli_args import args
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, database_path=None, pool_size=1):
|
||||
if database_path is None:
|
||||
self.exists = False
|
||||
database_path = "file::memory:?cache=shared"
|
||||
else:
|
||||
self.exists = os.path.exists(database_path)
|
||||
|
||||
self.database_path = database_path
|
||||
self.pool_size = pool_size
|
||||
# Store connections in a pool, default to 1 as normal usage is going to be from a single thread at a time
|
||||
self.connection_pool: Queue = Queue(maxsize=pool_size)
|
||||
self._db_lock = threading.Lock()
|
||||
self._initialized = False
|
||||
self._closing = False
|
||||
self._after_update_callbacks = []
|
||||
|
||||
def _setup(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
with self._db_lock:
|
||||
if not self._initialized:
|
||||
self._make_db()
|
||||
self._initialized = True
|
||||
|
||||
def _create_connection(self):
|
||||
# TODO: Catch error for sqlite lib missing on linux
|
||||
logging.info(f"Creating connection to {self.database_path}")
|
||||
conn = sqlite3.connect(
|
||||
self.database_path,
|
||||
check_same_thread=False,
|
||||
uri=self.database_path.startswith("file::"),
|
||||
)
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
self.exists = True
|
||||
logging.info(f"Connected!")
|
||||
return conn
|
||||
|
||||
def _make_db(self):
|
||||
with self._get_connection() as con:
|
||||
updater = DatabaseUpdater(con, self.database_path)
|
||||
result = updater.update()
|
||||
if result is not None:
|
||||
old_version, new_version = result
|
||||
|
||||
for callback in self._after_update_callbacks:
|
||||
callback(old_version, new_version)
|
||||
|
||||
def _transform(self, row, columns):
|
||||
return {col.name: value for value, col in zip(row, columns)}
|
||||
|
||||
@contextmanager
|
||||
def _get_connection(self):
|
||||
if self._closing:
|
||||
raise Exception("Database is shutting down")
|
||||
|
||||
try:
|
||||
# Try to get connection from pool
|
||||
connection = self.connection_pool.get_nowait()
|
||||
except Empty:
|
||||
# Create new connection if pool is empty
|
||||
connection = self._create_connection()
|
||||
|
||||
try:
|
||||
yield connection
|
||||
finally:
|
||||
try:
|
||||
# Try to add to pool if it's empty
|
||||
self.connection_pool.put_nowait(connection)
|
||||
except Full:
|
||||
# Pool is full, close the connection
|
||||
connection.close()
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
# Setup the database if it's not already initialized
|
||||
self._setup()
|
||||
with self._get_connection() as connection:
|
||||
yield connection
|
||||
|
||||
def execute(self, sql, *args):
|
||||
with self.get_connection() as connection:
|
||||
cursor = connection.execute(sql, args)
|
||||
results = cursor.fetchall()
|
||||
return results
|
||||
|
||||
def register_after_update_callback(self, callback):
|
||||
self._after_update_callbacks.append(callback)
|
||||
|
||||
def close(self):
|
||||
if self._closing:
|
||||
return
|
||||
# Drain and close all connections in the pool
|
||||
self._closing = True
|
||||
while True:
|
||||
try:
|
||||
conn = self.connection_pool.get_nowait()
|
||||
conn.close()
|
||||
except Empty:
|
||||
break
|
||||
self._closing = False
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# Create a global instance
|
||||
db_path = None
|
||||
if not args.memory_database:
|
||||
db_path = folder_paths.get_user_directory() + "/comfyui.db"
|
||||
db = Database(db_path)
|
343
app/database/entities.py
Normal file
343
app/database/entities.py
Normal file
@@ -0,0 +1,343 @@
|
||||
from typing import Optional, Any, Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from aiohttp import web
|
||||
from app.database.db import db
|
||||
|
||||
primitives = (bool, str, int, float, type(None))
|
||||
|
||||
|
||||
def is_primitive(obj):
|
||||
return isinstance(obj, primitives)
|
||||
|
||||
|
||||
class EntityError(Exception):
|
||||
def __init__(
|
||||
self, message: str, field: str = None, value: Any = None, status_code: int = 400
|
||||
):
|
||||
self.message = message
|
||||
self.field = field
|
||||
self.value = value
|
||||
self.status_code = status_code
|
||||
super().__init__(self.message)
|
||||
|
||||
def to_json(self):
|
||||
result = {"message": self.message}
|
||||
if self.field is not None:
|
||||
result["field"] = self.field
|
||||
if self.value is not None:
|
||||
result["value"] = self.value
|
||||
return result
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.message} {self.field} {self.value}"
|
||||
|
||||
|
||||
class EntityCommon(dict):
|
||||
@classmethod
|
||||
def _get_route(cls, include_key: bool):
|
||||
route = f"/db/{cls._table_name}"
|
||||
if include_key:
|
||||
route += "".join([f"/{{{k}}}" for k in cls._key_columns])
|
||||
return route
|
||||
|
||||
@classmethod
|
||||
def _register_route(cls, routes, verb: str, include_key: bool, handler: Callable):
|
||||
route = cls._get_route(include_key)
|
||||
|
||||
@getattr(routes, verb)(route)
|
||||
async def _(request):
|
||||
try:
|
||||
data = await handler(request)
|
||||
if data is None:
|
||||
return web.json_response(status=204)
|
||||
|
||||
return web.json_response(data)
|
||||
except EntityError as e:
|
||||
return web.json_response(e.to_json(), status=e.status_code)
|
||||
|
||||
@classmethod
|
||||
def _transform(cls, row: list[Any]):
|
||||
return {col: value for col, value in zip(cls._columns, row)}
|
||||
|
||||
@classmethod
|
||||
def _transform_rows(cls, rows: list[list[Any]]):
|
||||
return [cls._transform(row) for row in rows]
|
||||
|
||||
@classmethod
|
||||
def _extract_key(cls, request):
|
||||
return {key: request.match_info.get(key, None) for key in cls._key_columns}
|
||||
|
||||
@classmethod
|
||||
def _validate(cls, fields: list[str], data: dict, allow_missing: bool = False):
|
||||
result = {}
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise EntityError("Invalid data")
|
||||
|
||||
# Ensure all required fields are present
|
||||
for field in data:
|
||||
if field not in fields:
|
||||
raise EntityError("Unknown field", field)
|
||||
|
||||
for key in fields:
|
||||
col = cls._columns[key]
|
||||
if key not in data:
|
||||
if col.required and not allow_missing:
|
||||
raise EntityError("Missing field", key)
|
||||
else:
|
||||
# e.g. for updates, we allow missing fields
|
||||
continue
|
||||
elif data[key] is None and col.required:
|
||||
# Dont allow None for required fields
|
||||
raise EntityError("Required field", key)
|
||||
|
||||
# Validate data type
|
||||
value = data[key]
|
||||
|
||||
if value is not None and not is_primitive(value):
|
||||
raise EntityError("Invalid value", key, value)
|
||||
|
||||
try:
|
||||
type = col.type
|
||||
if value is not None and not isinstance(value, type):
|
||||
value = type(value)
|
||||
result[key] = value
|
||||
except Exception:
|
||||
raise EntityError("Invalid value", key, value)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _validate_id(cls, id: dict):
|
||||
return cls._validate(cls._key_columns, id)
|
||||
|
||||
@classmethod
|
||||
def _validate_data(cls, data: dict, allow_missing: bool = False):
|
||||
return cls._validate(cls._columns.keys(), data, allow_missing)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in self._columns:
|
||||
self[name] = value
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in self:
|
||||
return self[name]
|
||||
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
|
||||
|
||||
|
||||
class GetEntity(EntityCommon):
|
||||
@classmethod
|
||||
def get(cls, top: Optional[int] = None, where: Optional[str] = None):
|
||||
limit = ""
|
||||
if top is not None and isinstance(top, int):
|
||||
limit = f" LIMIT {top}"
|
||||
result = db.execute(
|
||||
f"SELECT * FROM {cls._table_name}{limit}{f' WHERE {where}' if where else ''}",
|
||||
)
|
||||
|
||||
# Map each row in result to an instance of the class
|
||||
return cls._transform_rows(result)
|
||||
|
||||
@classmethod
|
||||
def register_route(cls, routes):
|
||||
async def get_handler(request):
|
||||
top = request.rel_url.query.get("top", None)
|
||||
if top is not None:
|
||||
try:
|
||||
top = int(top)
|
||||
except Exception:
|
||||
raise EntityError("Invalid top parameter", "top", top)
|
||||
return cls.get(top)
|
||||
|
||||
cls._register_route(routes, "get", False, get_handler)
|
||||
|
||||
|
||||
class GetEntityById(EntityCommon):
|
||||
@classmethod
|
||||
def get_by_id(cls, id: dict):
|
||||
id = cls._validate_id(id)
|
||||
|
||||
result = db.execute(
|
||||
f"SELECT * FROM {cls._table_name} WHERE {cls._where_clause}",
|
||||
*[id[key] for key in cls._key_columns],
|
||||
)
|
||||
|
||||
return cls._transform_rows(result)
|
||||
|
||||
@classmethod
|
||||
def register_route(cls, routes):
|
||||
async def get_by_id_handler(request):
|
||||
id = cls._extract_key(request)
|
||||
return cls.get_by_id(id)
|
||||
|
||||
cls._register_route(routes, "get", True, get_by_id_handler)
|
||||
|
||||
|
||||
class CreateEntity(EntityCommon):
|
||||
@classmethod
|
||||
def create(cls, data: dict, allow_upsert: bool = False):
|
||||
data = cls._validate_data(data)
|
||||
values = ", ".join(["?"] * len(data))
|
||||
on_conflict = ""
|
||||
|
||||
data_keys = ", ".join(list(data.keys()))
|
||||
if allow_upsert:
|
||||
# Remove key columns from data
|
||||
upsert_keys = [key for key in data if key not in cls._key_columns]
|
||||
|
||||
set_clause = ", ".join([f"{k} = excluded.{k}" for k in upsert_keys])
|
||||
on_conflict = f" ON CONFLICT ({', '.join(cls._key_columns)}) DO UPDATE SET {set_clause}"
|
||||
sql = f"INSERT INTO {cls._table_name} ({data_keys}) VALUES ({values}){on_conflict} RETURNING *"
|
||||
result = db.execute(
|
||||
sql,
|
||||
*[data[key] for key in data],
|
||||
)
|
||||
|
||||
if len(result) == 0:
|
||||
raise EntityError("Failed to create entity", status_code=500)
|
||||
|
||||
return cls._transform_rows(result)[0]
|
||||
|
||||
@classmethod
|
||||
def register_route(cls, routes):
|
||||
async def create_handler(request):
|
||||
data = await request.json()
|
||||
return cls.create(data)
|
||||
|
||||
cls._register_route(routes, "post", False, create_handler)
|
||||
|
||||
|
||||
class UpdateEntity(EntityCommon):
|
||||
@classmethod
|
||||
def update(cls, id: list, data: dict):
|
||||
id = cls._validate_id(id)
|
||||
data = cls._validate_data(data, allow_missing=True)
|
||||
|
||||
sql = f"UPDATE {cls._table_name} SET {', '.join([f'{k} = ?' for k in data])} WHERE {cls._where_clause} RETURNING *"
|
||||
result = db.execute(
|
||||
sql,
|
||||
*[data[key] for key in data],
|
||||
*[id[key] for key in cls._key_columns],
|
||||
)
|
||||
|
||||
if len(result) == 0:
|
||||
raise EntityError("Failed to update entity", status_code=404)
|
||||
|
||||
return cls._transform_rows(result)[0]
|
||||
|
||||
@classmethod
|
||||
def register_route(cls, routes):
|
||||
async def update_handler(request):
|
||||
id = cls._extract_key(request)
|
||||
data = await request.json()
|
||||
return cls.update(id, data)
|
||||
|
||||
cls._register_route(routes, "patch", True, update_handler)
|
||||
|
||||
|
||||
class UpsertEntity(CreateEntity):
|
||||
@classmethod
|
||||
def upsert(cls, data: dict):
|
||||
return cls.create(data, allow_upsert=True)
|
||||
|
||||
@classmethod
|
||||
def register_route(cls, routes):
|
||||
async def upsert_handler(request):
|
||||
data = await request.json()
|
||||
return cls.upsert(data)
|
||||
|
||||
cls._register_route(routes, "put", False, upsert_handler)
|
||||
|
||||
|
||||
class DeleteEntity(EntityCommon):
|
||||
@classmethod
|
||||
def delete(cls, id: list):
|
||||
id = cls._validate_id(id)
|
||||
db.execute(
|
||||
f"DELETE FROM {cls._table_name} WHERE {cls._where_clause}",
|
||||
*[id[key] for key in cls._key_columns],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def register_route(cls, routes):
|
||||
async def delete_handler(request):
|
||||
id = cls._extract_key(request)
|
||||
cls.delete(id)
|
||||
|
||||
cls._register_route(routes, "delete", True, delete_handler)
|
||||
|
||||
|
||||
class BaseEntity(GetEntity, CreateEntity, UpdateEntity, DeleteEntity, GetEntityById):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Column:
|
||||
type: Any
|
||||
required: bool = False
|
||||
key: bool = False
|
||||
default: Any = None
|
||||
|
||||
|
||||
def column(type_: Any, required: bool = False, key: bool = False, default: Any = None):
|
||||
return Column(type_, required, key, default)
|
||||
|
||||
|
||||
def table(table_name: str):
|
||||
def decorator(cls):
|
||||
# Store table name
|
||||
cls._table_name = table_name
|
||||
|
||||
# Process column definitions
|
||||
columns: dict[str, Column] = {}
|
||||
for attr_name, attr_value in cls.__dict__.items():
|
||||
if isinstance(attr_value, Column):
|
||||
columns[attr_name] = attr_value
|
||||
|
||||
# Store columns metadata
|
||||
cls._columns = columns
|
||||
cls._key_columns = [col for col in columns if columns[col].key]
|
||||
cls._column_csv = ", ".join([col for col in columns])
|
||||
cls._where_clause = " AND ".join([f"{col} = ?" for col in cls._key_columns])
|
||||
|
||||
# Add initialization
|
||||
original_init = cls.__init__
|
||||
|
||||
@wraps(original_init)
|
||||
def new_init(self, *args, **kwargs):
|
||||
# Initialize columns with default values
|
||||
for col_name, col_def in cls._columns.items():
|
||||
setattr(self, col_name, col_def.default)
|
||||
# Call original init
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
cls.__init__ = new_init
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def test():
|
||||
@table("models")
|
||||
class Model(BaseEntity):
|
||||
id: int = column(int, required=True, key=True)
|
||||
path: str = column(str, required=True)
|
||||
name: str = column(str, required=True)
|
||||
description: Optional[str] = column(str)
|
||||
architecture: Optional[str] = column(str)
|
||||
type: str = column(str, required=True)
|
||||
hash: Optional[str] = column(str)
|
||||
source_url: Optional[str] = column(str)
|
||||
|
||||
return Model
|
||||
|
||||
|
||||
@table("test")
|
||||
class Test(GetEntity, CreateEntity):
|
||||
id: int = column(int, required=True, key=True)
|
||||
test: str = column(str, required=True)
|
||||
|
||||
|
||||
Model = test()
|
32
app/database/routes.py
Normal file
32
app/database/routes.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from app.database.db import db
|
||||
from aiohttp import web
|
||||
|
||||
def create_routes(
|
||||
routes, prefix, entity, get=False, get_by_id=False, post=False, delete=False
|
||||
):
|
||||
if get:
|
||||
@routes.get(f"/{prefix}/{table}")
|
||||
async def get_table(request):
|
||||
connection = db.get_connection()
|
||||
cursor = connection.cursor()
|
||||
cursor.execute(f"SELECT * FROM {table}")
|
||||
rows = cursor.fetchall()
|
||||
return web.json_response(rows)
|
||||
|
||||
if get_by_id:
|
||||
@routes.get(f"/{prefix}/{table}/{id}")
|
||||
async def get_table_by_id(request):
|
||||
connection = db.get_connection()
|
||||
cursor = connection.cursor()
|
||||
cursor.execute(f"SELECT * FROM {table} WHERE id = {id}")
|
||||
row = cursor.fetchone()
|
||||
return web.json_response(row)
|
||||
|
||||
if post:
|
||||
@routes.post(f"/{prefix}/{table}")
|
||||
async def post_table(request):
|
||||
data = await request.json()
|
||||
connection = db.get_connection()
|
||||
cursor = connection.cursor()
|
||||
cursor.execute(f"INSERT INTO {table} ({data}) VALUES ({data})")
|
||||
return web.json_response({"status": "success"})
|
79
app/database/updater.py
Normal file
79
app/database/updater.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
from app.database.versions.v1 import v1
|
||||
|
||||
|
||||
class DatabaseUpdater:
|
||||
def __init__(self, connection, database_path):
|
||||
self.connection = connection
|
||||
self.database_path = database_path
|
||||
self.current_version = self.get_db_version()
|
||||
self.version_updates = {
|
||||
1: v1,
|
||||
}
|
||||
self.max_version = max(self.version_updates.keys())
|
||||
self.update_required = self.current_version < self.max_version
|
||||
logging.info(f"Database version: {self.current_version}")
|
||||
|
||||
def get_db_version(self):
|
||||
return self.connection.execute("PRAGMA user_version").fetchone()[0]
|
||||
|
||||
def backup(self):
|
||||
bkp_path = self.database_path + ".bkp"
|
||||
if os.path.exists(bkp_path):
|
||||
# TODO: auto-rollback failed upgrades
|
||||
raise Exception(
|
||||
f"Database backup already exists, this indicates that a previous upgrade failed. Please restore this backup before continuing. Backup location: {bkp_path}"
|
||||
)
|
||||
|
||||
bkp = sqlite3.connect(bkp_path)
|
||||
self.connection.backup(bkp)
|
||||
bkp.close()
|
||||
logging.info("Database backup taken pre-upgrade.")
|
||||
return bkp_path
|
||||
|
||||
def update(self):
|
||||
if not self.update_required:
|
||||
return None
|
||||
|
||||
bkp_version = self.current_version
|
||||
bkp_path = None
|
||||
if self.current_version > 0:
|
||||
bkp_path = self.backup()
|
||||
|
||||
logging.info(f"Updating database: {self.current_version} -> {self.max_version}")
|
||||
|
||||
dirname = os.path.dirname(__file__)
|
||||
cursor = self.connection.cursor()
|
||||
for version in range(self.current_version + 1, self.max_version + 1):
|
||||
filename = os.path.join(dirname, f"versions/v{version}.sql")
|
||||
if not os.path.exists(filename):
|
||||
raise Exception(
|
||||
f"Database update script for version {version} not found"
|
||||
)
|
||||
|
||||
try:
|
||||
with open(filename, "r") as file:
|
||||
sql = file.read()
|
||||
cursor.executescript(sql)
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Failed to execute update script for version {version}: {e}"
|
||||
)
|
||||
|
||||
method = self.version_updates[version]
|
||||
if method is not None:
|
||||
method(cursor)
|
||||
|
||||
cursor.execute("PRAGMA user_version = %d" % self.max_version)
|
||||
self.connection.commit()
|
||||
cursor.close()
|
||||
self.current_version = self.get_db_version()
|
||||
|
||||
if bkp_path:
|
||||
# Keep a copy of the backup in case something goes wrong and we need to rollback
|
||||
os.rename(bkp_path, self.database_path + f".v{bkp_version}.bkp")
|
||||
logging.info(f"Upgrade to successful.")
|
||||
|
||||
return (bkp_version, self.current_version)
|
17
app/database/versions/v1.py
Normal file
17
app/database/versions/v1.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from folder_paths import folder_names_and_paths, get_filename_list, get_full_path
|
||||
|
||||
|
||||
def v1(cursor):
|
||||
print("Updating to v1")
|
||||
for folder_name in folder_names_and_paths.keys():
|
||||
if folder_name == "custom_nodes":
|
||||
continue
|
||||
|
||||
files = get_filename_list(folder_name)
|
||||
for file in files:
|
||||
file_path = get_full_path(folder_name, file)
|
||||
file_without_extension = file.rsplit(".", maxsplit=1)[0]
|
||||
cursor.execute(
|
||||
"INSERT INTO models (path, name, type) VALUES (?, ?, ?)",
|
||||
(file_path, file_without_extension, folder_name),
|
||||
)
|
41
app/database/versions/v1.sql
Normal file
41
app/database/versions/v1.sql
Normal file
@@ -0,0 +1,41 @@
|
||||
CREATE TABLE IF NOT EXISTS
|
||||
models (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
path TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
architecture TEXT,
|
||||
type TEXT NOT NULL,
|
||||
hash TEXT,
|
||||
source_url TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS
|
||||
tags (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL UNIQUE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS
|
||||
model_tags (
|
||||
model_id INTEGER NOT NULL,
|
||||
tag_id INTEGER NOT NULL,
|
||||
PRIMARY KEY (model_id, tag_id),
|
||||
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (tag_id) REFERENCES tags (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
INSERT INTO
|
||||
tags (name)
|
||||
VALUES
|
||||
('character'),
|
||||
('style'),
|
||||
('concept'),
|
||||
('clothing'),
|
||||
('poses'),
|
||||
('background'),
|
||||
('vehicle'),
|
||||
('buildings'),
|
||||
('objects'),
|
||||
('animal'),
|
||||
('action');
|
63
app/model_hasher.py
Normal file
63
app/model_hasher.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from comfy.cli_args import args
|
||||
|
||||
|
||||
class ModelHasher:
|
||||
|
||||
def __init__(self):
|
||||
self._thread = None
|
||||
self._lock = threading.Lock()
|
||||
self._model_entity = None
|
||||
|
||||
def start(self, model_entity):
|
||||
if args.disable_model_hashing:
|
||||
return
|
||||
|
||||
self._model_entity = model_entity
|
||||
|
||||
if self._thread is None:
|
||||
# Lock to prevent multiple threads from starting
|
||||
with self._lock:
|
||||
if self._thread is None:
|
||||
self._thread = threading.Thread(target=self._hash_models)
|
||||
self._thread.daemon = True
|
||||
self._thread.start()
|
||||
|
||||
def _get_models(self):
|
||||
models = self._model_entity.get("WHERE hash IS NULL")
|
||||
return models
|
||||
|
||||
def _hash_model(self, model_path):
|
||||
h = hashlib.sha256()
|
||||
b = bytearray(128 * 1024)
|
||||
mv = memoryview(b)
|
||||
with open(model_path, "rb", buffering=0) as f:
|
||||
while n := f.readinto(mv):
|
||||
h.update(mv[:n])
|
||||
hash = h.hexdigest()
|
||||
return hash
|
||||
|
||||
def _hash_models(self):
|
||||
while True:
|
||||
models = self._get_models()
|
||||
|
||||
if len(models) == 0:
|
||||
break
|
||||
|
||||
for model in models:
|
||||
time.sleep(0)
|
||||
now = time.time()
|
||||
logging.info(f"Hashing model {model['path']}")
|
||||
hash = self._hash_model(model["path"])
|
||||
logging.info(
|
||||
f"Hashed model {model['path']} in {time.time() - now} seconds"
|
||||
)
|
||||
self._model_entity.update((model["id"],), {"hash": hash})
|
||||
|
||||
self._thread = None
|
||||
|
||||
|
||||
model_hasher = ModelHasher()
|
@@ -143,9 +143,13 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
parser.add_argument("--memory-database", default=False, action="store_true", help="Use an in-memory database instead of a file-based one.")
|
||||
parser.add_argument("--disable-model-hashing", action="store_true", help="Disable model hashing.")
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--front-end-version",
|
||||
type=str,
|
||||
|
17
server.py
17
server.py
@@ -34,6 +34,9 @@ from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from typing import Optional
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from app.database.entities import get_entity, init_entities
|
||||
from app.database.db import db
|
||||
from app.model_hasher import model_hasher
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
@@ -682,11 +685,25 @@ class PromptServer():
|
||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||
|
||||
def init_db(self, routes):
|
||||
init_entities(routes)
|
||||
models = get_entity("models")
|
||||
|
||||
if db.exists:
|
||||
model_hasher.start(models)
|
||||
else:
|
||||
def on_db_update(_, __):
|
||||
model_hasher.start(models)
|
||||
|
||||
db.register_after_update_callback(on_db_update)
|
||||
|
||||
|
||||
def add_routes(self):
|
||||
self.user_manager.add_routes(self.routes)
|
||||
self.model_file_manager.add_routes(self.routes)
|
||||
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||
self.init_db(self.routes)
|
||||
|
||||
# Prefix every route with /api for easier matching for delegation.
|
||||
# This is very useful for frontend dev server, which need to forward
|
||||
|
513
tests-unit/app_test/entities_test.py
Normal file
513
tests-unit/app_test/entities_test.py
Normal file
@@ -0,0 +1,513 @@
|
||||
from comfy.cli_args import args
|
||||
|
||||
args.memory_database = True # force in-memory database for testing
|
||||
|
||||
from typing import Callable, Optional
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch
|
||||
from aiohttp import web
|
||||
from app.database.entities import (
|
||||
DeleteEntity,
|
||||
column,
|
||||
table,
|
||||
Column,
|
||||
GetEntity,
|
||||
GetEntityById,
|
||||
CreateEntity,
|
||||
UpsertEntity,
|
||||
UpdateEntity,
|
||||
)
|
||||
from app.database.db import db
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def create_table(entity):
|
||||
# reset db
|
||||
db.close()
|
||||
|
||||
cols: list[Column] = entity._columns
|
||||
# Create tables as temporary so when we close the db, the tables are dropped for next test
|
||||
sql = f"CREATE TEMPORARY TABLE {entity._table_name} ( "
|
||||
for col_name, col in cols.items():
|
||||
type = None
|
||||
if col.type == int:
|
||||
type = "INTEGER"
|
||||
elif col.type == str:
|
||||
type = "TEXT"
|
||||
|
||||
sql += f"{col_name} {type}"
|
||||
if col.required:
|
||||
sql += " NOT NULL"
|
||||
sql += ", "
|
||||
|
||||
sql += f"PRIMARY KEY ({', '.join(entity._key_columns)})"
|
||||
sql += ")"
|
||||
db.execute(sql)
|
||||
|
||||
|
||||
async def wrap_db(method: Callable, expected_sql: str, expected_args: list):
|
||||
with patch.object(db, "execute", wraps=db.execute) as mock:
|
||||
response = await method()
|
||||
assert mock.call_count == 1
|
||||
assert mock.call_args[0][0] == expected_sql
|
||||
assert mock.call_args[0][1:] == expected_args
|
||||
return response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def getable_entity():
|
||||
@table("getable_entity")
|
||||
class GetableEntity(GetEntity):
|
||||
id: int = column(int, required=True, key=True)
|
||||
test: str = column(str, required=True)
|
||||
nullable: Optional[str] = column(str)
|
||||
|
||||
return GetableEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def getable_by_id_entity():
|
||||
@table("getable_by_id_entity")
|
||||
class GetableByIdEntity(GetEntityById):
|
||||
id: int = column(int, required=True, key=True)
|
||||
test: str = column(str, required=True)
|
||||
|
||||
return GetableByIdEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def getable_by_id_composite_entity():
|
||||
@table("getable_by_id_composite_entity")
|
||||
class GetableByIdCompositeEntity(GetEntityById):
|
||||
id1: str = column(str, required=True, key=True)
|
||||
id2: int = column(int, required=True, key=True)
|
||||
test: str = column(str, required=True)
|
||||
|
||||
return GetableByIdCompositeEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def creatable_entity():
|
||||
@table("creatable_entity")
|
||||
class CreatableEntity(CreateEntity):
|
||||
id: int = column(int, required=True, key=True)
|
||||
test: str = column(str, required=True)
|
||||
reqd: str = column(str, required=True)
|
||||
nullable: Optional[str] = column(str)
|
||||
|
||||
return CreatableEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def upsertable_entity():
|
||||
@table("upsertable_entity")
|
||||
class UpsertableEntity(UpsertEntity):
|
||||
id: int = column(int, required=True, key=True)
|
||||
test: str = column(str, required=True)
|
||||
reqd: str = column(str, required=True)
|
||||
nullable: Optional[str] = column(str)
|
||||
|
||||
return UpsertableEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def updateable_entity():
|
||||
@table("updateable_entity")
|
||||
class UpdateableEntity(UpdateEntity):
|
||||
id: int = column(int, required=True, key=True)
|
||||
reqd: str = column(str, required=True)
|
||||
|
||||
return UpdateableEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def deletable_entity():
|
||||
@table("deletable_entity")
|
||||
class DeletableEntity(DeleteEntity):
|
||||
id: int = column(int, required=True, key=True)
|
||||
|
||||
return DeletableEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def deletable_composite_entity():
|
||||
@table("deletable_composite_entity")
|
||||
class DeletableCompositeEntity(DeleteEntity):
|
||||
id1: str = column(str, required=True, key=True)
|
||||
id2: int = column(int, required=True, key=True)
|
||||
|
||||
return DeletableCompositeEntity
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def entity(request):
|
||||
value = request.getfixturevalue(request.param)
|
||||
create_table(value)
|
||||
return value
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(aiohttp_client, app):
|
||||
return await aiohttp_client(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(entity):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
entity.register_route(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
|
||||
async def test_get_model_empty_response(client):
|
||||
expected_sql = "SELECT * FROM getable_entity"
|
||||
expected_args = ()
|
||||
response = await wrap_db(
|
||||
lambda: client.get("/db/getable_entity"), expected_sql, expected_args
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
|
||||
async def test_get_model_with_data(client):
|
||||
# seed db
|
||||
db.execute(
|
||||
"INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2')"
|
||||
)
|
||||
|
||||
expected_sql = "SELECT * FROM getable_entity"
|
||||
expected_args = ()
|
||||
response = await wrap_db(
|
||||
lambda: client.get("/db/getable_entity"), expected_sql, expected_args
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == [
|
||||
{"id": 1, "test": "test1", "nullable": None},
|
||||
{"id": 2, "test": "test2", "nullable": "test2"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
|
||||
async def test_get_model_with_top_parameter(client):
|
||||
# seed with 3 rows
|
||||
db.execute(
|
||||
"INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2'), (3, 'test3', 'test3')"
|
||||
)
|
||||
|
||||
expected_sql = "SELECT * FROM getable_entity LIMIT 2"
|
||||
expected_args = ()
|
||||
response = await wrap_db(
|
||||
lambda: client.get("/db/getable_entity?top=2"),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == [
|
||||
{"id": 1, "test": "test1", "nullable": None},
|
||||
{"id": 2, "test": "test2", "nullable": "test2"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
|
||||
async def test_get_model_with_invalid_top_parameter(client):
|
||||
response = await client.get("/db/getable_entity?top=hello")
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Invalid top parameter",
|
||||
"field": "top",
|
||||
"value": "hello",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
|
||||
async def test_get_model_by_id_empty_response(client):
|
||||
# seed db
|
||||
db.execute("INSERT INTO getable_by_id_entity (id, test) VALUES (1, 'test1')")
|
||||
|
||||
expected_sql = "SELECT * FROM getable_by_id_entity WHERE id = ?"
|
||||
expected_args = (1,)
|
||||
response = await wrap_db(
|
||||
lambda: client.get("/db/getable_by_id_entity/1"),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == [
|
||||
{"id": 1, "test": "test1"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
|
||||
async def test_get_model_by_id_with_invalid_id(client):
|
||||
response = await client.get("/db/getable_by_id_entity/hello")
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Invalid value",
|
||||
"field": "id",
|
||||
"value": "hello",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
|
||||
async def test_get_model_by_id_composite(client):
|
||||
# seed db
|
||||
db.execute(
|
||||
"INSERT INTO getable_by_id_composite_entity (id1, id2, test) VALUES ('one', 2, 'test')"
|
||||
)
|
||||
|
||||
expected_sql = (
|
||||
"SELECT * FROM getable_by_id_composite_entity WHERE id1 = ? AND id2 = ?"
|
||||
)
|
||||
expected_args = ("one", 2)
|
||||
response = await wrap_db(
|
||||
lambda: client.get("/db/getable_by_id_composite_entity/one/2"),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == [
|
||||
{"id1": "one", "id2": 2, "test": "test"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
|
||||
async def test_get_model_by_id_composite_with_invalid_id(client):
|
||||
response = await client.get("/db/getable_by_id_composite_entity/hello/hello")
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Invalid value",
|
||||
"field": "id2",
|
||||
"value": "hello",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model(client):
|
||||
expected_sql = (
|
||||
"INSERT INTO creatable_entity (id, test, reqd) VALUES (?, ?, ?) RETURNING *"
|
||||
)
|
||||
expected_args = (1, "test1", "reqd1")
|
||||
response = await wrap_db(
|
||||
lambda: client.post(
|
||||
"/db/creatable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
|
||||
),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == {
|
||||
"id": 1,
|
||||
"test": "test1",
|
||||
"reqd": "reqd1",
|
||||
"nullable": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model_missing_required_field(client):
|
||||
response = await client.post(
|
||||
"/db/creatable_entity", json={"id": 1, "test": "test1"}
|
||||
)
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Missing field",
|
||||
"field": "reqd",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model_missing_key_field(client):
|
||||
response = await client.post(
|
||||
"/db/creatable_entity",
|
||||
json={"test": "test1", "reqd": "reqd1"}, # Missing 'id' which is a key
|
||||
)
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Missing field",
|
||||
"field": "id",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model_invalid_key_data(client):
|
||||
response = await client.post(
|
||||
"/db/creatable_entity",
|
||||
json={
|
||||
"id": "not_an_integer",
|
||||
"test": "test1",
|
||||
"reqd": "reqd1",
|
||||
}, # id should be int
|
||||
)
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Invalid value",
|
||||
"field": "id",
|
||||
"value": "not_an_integer",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model_invalid_field_data(client):
|
||||
response = await client.post(
|
||||
"/db/creatable_entity",
|
||||
json={"id": "aaa", "test": "123", "reqd": "reqd1"}, # id should be int
|
||||
)
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Invalid value",
|
||||
"field": "id",
|
||||
"value": "aaa",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model_invalid_field_type(client):
|
||||
response = await client.post(
|
||||
"/db/creatable_entity",
|
||||
json={
|
||||
"id": 1,
|
||||
"test": ["invalid_array"],
|
||||
"reqd": "reqd1",
|
||||
}, # test should be string
|
||||
)
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Invalid value",
|
||||
"field": "test",
|
||||
"value": ["invalid_array"],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model_invalid_field_name(client):
|
||||
response = await client.post(
|
||||
"/db/creatable_entity",
|
||||
json={"id": 1, "test": "test1", "reqd": "reqd1", "nonexistent_field": "value"},
|
||||
)
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Unknown field",
|
||||
"field": "nonexistent_field",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["upsertable_entity"], indirect=True)
|
||||
async def test_upsert_model(client):
|
||||
expected_sql = (
|
||||
"INSERT INTO upsertable_entity (id, test, reqd) VALUES (?, ?, ?) "
|
||||
"ON CONFLICT (id) DO UPDATE SET test = excluded.test, reqd = excluded.reqd "
|
||||
"RETURNING *"
|
||||
)
|
||||
expected_args = (1, "test1", "reqd1")
|
||||
response = await wrap_db(
|
||||
lambda: client.put(
|
||||
"/db/upsertable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
|
||||
),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == {
|
||||
"id": 1,
|
||||
"test": "test1",
|
||||
"reqd": "reqd1",
|
||||
"nullable": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
|
||||
async def test_update_model(client):
|
||||
# seed db
|
||||
db.execute("INSERT INTO updateable_entity (id, reqd) VALUES (1, 'test1')")
|
||||
|
||||
expected_sql = "UPDATE updateable_entity SET reqd = ? WHERE id = ? RETURNING *"
|
||||
expected_args = ("updated_test", 1)
|
||||
response = await wrap_db(
|
||||
lambda: client.patch("/db/updateable_entity/1", json={"reqd": "updated_test"}),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == {
|
||||
"id": 1,
|
||||
"reqd": "updated_test",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
|
||||
async def test_update_model_reject_null_required_field(client):
|
||||
response = await client.patch("/db/updateable_entity/1", json={"reqd": None})
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Required field",
|
||||
"field": "reqd",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
|
||||
async def test_update_model_reject_invalid_field(client):
|
||||
response = await client.patch("/db/updateable_entity/1", json={"hello": "world"})
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Unknown field",
|
||||
"field": "hello",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
|
||||
async def test_update_model_reject_missing_record(client):
|
||||
response = await client.patch(
|
||||
"/db/updateable_entity/1", json={"reqd": "updated_test"}
|
||||
)
|
||||
|
||||
assert response.status == 404
|
||||
assert await response.json() == {
|
||||
"message": "Failed to update entity",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["deletable_entity"], indirect=True)
|
||||
async def test_delete_model(client):
|
||||
expected_sql = "DELETE FROM deletable_entity WHERE id = ?"
|
||||
expected_args = (1,)
|
||||
response = await wrap_db(
|
||||
lambda: client.delete("/db/deletable_entity/1"),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 204
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["deletable_composite_entity"], indirect=True)
|
||||
async def test_delete_model_composite_key(client):
|
||||
expected_sql = "DELETE FROM deletable_composite_entity WHERE id1 = ? AND id2 = ?"
|
||||
expected_args = ("one", 2)
|
||||
response = await wrap_db(
|
||||
lambda: client.delete("/db/deletable_composite_entity/one/2"),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 204
|
Reference in New Issue
Block a user