Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-08-02 15:04:50 +08:00
Compare commits
3 Commits
5d5024296d ... model_mana

Author | SHA1 | Date
---|---|---
 | 01110de8a3 |
 | 785a220757 |
 | b6b475191d |
0    app/database/__init__.py    Normal file
126  app/database/db.py    Normal file
@@ -0,0 +1,126 @@
import logging
import os
import sqlite3
from contextlib import contextmanager
from queue import Queue, Empty, Full
import threading
from app.database.updater import DatabaseUpdater
import folder_paths
from comfy.cli_args import args


class Database:
    def __init__(self, database_path=None, pool_size=1):
        if database_path is None:
            self.exists = False
            database_path = "file::memory:?cache=shared"
        else:
            self.exists = os.path.exists(database_path)

        self.database_path = database_path
        self.pool_size = pool_size
        # Store connections in a pool; default to 1 as normal usage is going to be from a single thread at a time
        self.connection_pool: Queue = Queue(maxsize=pool_size)
        self._db_lock = threading.Lock()
        self._initialized = False
        self._closing = False
        self._after_update_callbacks = []

    def _setup(self):
        if self._initialized:
            return

        with self._db_lock:
            if not self._initialized:
                self._make_db()
                self._initialized = True

    def _create_connection(self):
        # TODO: Catch error for sqlite lib missing on linux
        logging.info(f"Creating connection to {self.database_path}")
        conn = sqlite3.connect(
            self.database_path,
            check_same_thread=False,
            uri=self.database_path.startswith("file::"),
        )
        conn.execute("PRAGMA foreign_keys = ON")
        self.exists = True
        logging.info("Connected!")
        return conn

    def _make_db(self):
        with self._get_connection() as con:
            updater = DatabaseUpdater(con, self.database_path)
            result = updater.update()
            if result is not None:
                old_version, new_version = result

                for callback in self._after_update_callbacks:
                    callback(old_version, new_version)

    def _transform(self, row, columns):
        return {col.name: value for value, col in zip(row, columns)}

    @contextmanager
    def _get_connection(self):
        if self._closing:
            raise Exception("Database is shutting down")

        try:
            # Try to get a connection from the pool
            connection = self.connection_pool.get_nowait()
        except Empty:
            # Create a new connection if the pool is empty
            connection = self._create_connection()

        try:
            yield connection
        finally:
            try:
                # Try to return the connection to the pool
                self.connection_pool.put_nowait(connection)
            except Full:
                # Pool is full, close the connection
                connection.close()

    @contextmanager
    def get_connection(self):
        # Set up the database if it's not already initialized
        self._setup()
        with self._get_connection() as connection:
            yield connection

    def execute(self, sql, *args):
        with self.get_connection() as connection:
            cursor = connection.execute(sql, args)
            results = cursor.fetchall()
            return results

    def register_after_update_callback(self, callback):
        self._after_update_callbacks.append(callback)

    def close(self):
        if self._closing:
            return
        # Drain and close all connections in the pool
        self._closing = True
        while True:
            try:
                conn = self.connection_pool.get_nowait()
                conn.close()
            except Empty:
                break
        self._closing = False

    def __del__(self):
        try:
            self.close()
        except Exception:
            pass


# Create a global instance
db_path = None
if not args.memory_database:
    db_path = folder_paths.get_user_directory() + "/comfyui.db"
db = Database(db_path)
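For orientation, a minimal usage sketch of the pooled Database wrapper above. The queries target the models and tags tables created by the v1 migration further down; the specific values ('checkpoints', 'example_tag') are only illustrative:

    from app.database.db import db

    # execute() lazily runs the schema setup on first use, borrows a pooled
    # connection, and returns cursor.fetchall().
    rows = db.execute("SELECT name, type FROM models WHERE type = ?", "checkpoints")

    # For multi-statement work, hold a single connection via the context manager.
    with db.get_connection() as conn:
        conn.execute("INSERT INTO tags (name) VALUES (?)", ("example_tag",))
        conn.commit()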
343  app/database/entities.py    Normal file
@@ -0,0 +1,343 @@
from typing import Optional, Any, Callable
from dataclasses import dataclass
from functools import wraps
from aiohttp import web
from app.database.db import db

primitives = (bool, str, int, float, type(None))


def is_primitive(obj):
    return isinstance(obj, primitives)


class EntityError(Exception):
    def __init__(
        self, message: str, field: str = None, value: Any = None, status_code: int = 400
    ):
        self.message = message
        self.field = field
        self.value = value
        self.status_code = status_code
        super().__init__(self.message)

    def to_json(self):
        result = {"message": self.message}
        if self.field is not None:
            result["field"] = self.field
        if self.value is not None:
            result["value"] = self.value
        return result

    def __str__(self) -> str:
        return f"{self.message} {self.field} {self.value}"


class EntityCommon(dict):
    @classmethod
    def _get_route(cls, include_key: bool):
        route = f"/db/{cls._table_name}"
        if include_key:
            route += "".join([f"/{{{k}}}" for k in cls._key_columns])
        return route

    @classmethod
    def _register_route(cls, routes, verb: str, include_key: bool, handler: Callable):
        route = cls._get_route(include_key)

        @getattr(routes, verb)(route)
        async def _(request):
            try:
                data = await handler(request)
                if data is None:
                    return web.json_response(status=204)

                return web.json_response(data)
            except EntityError as e:
                return web.json_response(e.to_json(), status=e.status_code)

    @classmethod
    def _transform(cls, row: list[Any]):
        return {col: value for col, value in zip(cls._columns, row)}

    @classmethod
    def _transform_rows(cls, rows: list[list[Any]]):
        return [cls._transform(row) for row in rows]

    @classmethod
    def _extract_key(cls, request):
        return {key: request.match_info.get(key, None) for key in cls._key_columns}

    @classmethod
    def _validate(cls, fields: list[str], data: dict, allow_missing: bool = False):
        result = {}

        if not isinstance(data, dict):
            raise EntityError("Invalid data")

        # Reject any fields that are not defined on the entity
        for field in data:
            if field not in fields:
                raise EntityError("Unknown field", field)

        for key in fields:
            col = cls._columns[key]
            if key not in data:
                if col.required and not allow_missing:
                    raise EntityError("Missing field", key)
                else:
                    # e.g. for updates, we allow missing fields
                    continue
            elif data[key] is None and col.required:
                # Don't allow None for required fields
                raise EntityError("Required field", key)

            # Validate data type
            value = data[key]

            if value is not None and not is_primitive(value):
                raise EntityError("Invalid value", key, value)

            try:
                type = col.type
                if value is not None and not isinstance(value, type):
                    value = type(value)
                result[key] = value
            except Exception:
                raise EntityError("Invalid value", key, value)

        return result

    @classmethod
    def _validate_id(cls, id: dict):
        return cls._validate(cls._key_columns, id)

    @classmethod
    def _validate_data(cls, data: dict, allow_missing: bool = False):
        return cls._validate(cls._columns.keys(), data, allow_missing)

    def __setattr__(self, name, value):
        if name in self._columns:
            self[name] = value
        super().__setattr__(name, value)

    def __getattr__(self, name):
        if name in self:
            return self[name]
        raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")


class GetEntity(EntityCommon):
    @classmethod
    def get(cls, top: Optional[int] = None, where: Optional[str] = None):
        limit = ""
        if top is not None and isinstance(top, int):
            limit = f" LIMIT {top}"
        result = db.execute(
            f"SELECT * FROM {cls._table_name}{limit}{f' WHERE {where}' if where else ''}",
        )

        # Map each row in result to an instance of the class
        return cls._transform_rows(result)

    @classmethod
    def register_route(cls, routes):
        async def get_handler(request):
            top = request.rel_url.query.get("top", None)
            if top is not None:
                try:
                    top = int(top)
                except Exception:
                    raise EntityError("Invalid top parameter", "top", top)
            return cls.get(top)

        cls._register_route(routes, "get", False, get_handler)


class GetEntityById(EntityCommon):
    @classmethod
    def get_by_id(cls, id: dict):
        id = cls._validate_id(id)

        result = db.execute(
            f"SELECT * FROM {cls._table_name} WHERE {cls._where_clause}",
            *[id[key] for key in cls._key_columns],
        )

        return cls._transform_rows(result)

    @classmethod
    def register_route(cls, routes):
        async def get_by_id_handler(request):
            id = cls._extract_key(request)
            return cls.get_by_id(id)

        cls._register_route(routes, "get", True, get_by_id_handler)


class CreateEntity(EntityCommon):
    @classmethod
    def create(cls, data: dict, allow_upsert: bool = False):
        data = cls._validate_data(data)
        values = ", ".join(["?"] * len(data))
        on_conflict = ""

        data_keys = ", ".join(list(data.keys()))
        if allow_upsert:
            # Remove key columns from data
            upsert_keys = [key for key in data if key not in cls._key_columns]

            set_clause = ", ".join([f"{k} = excluded.{k}" for k in upsert_keys])
            on_conflict = f" ON CONFLICT ({', '.join(cls._key_columns)}) DO UPDATE SET {set_clause}"
        sql = f"INSERT INTO {cls._table_name} ({data_keys}) VALUES ({values}){on_conflict} RETURNING *"
        result = db.execute(
            sql,
            *[data[key] for key in data],
        )

        if len(result) == 0:
            raise EntityError("Failed to create entity", status_code=500)

        return cls._transform_rows(result)[0]

    @classmethod
    def register_route(cls, routes):
        async def create_handler(request):
            data = await request.json()
            return cls.create(data)

        cls._register_route(routes, "post", False, create_handler)


class UpdateEntity(EntityCommon):
    @classmethod
    def update(cls, id: list, data: dict):
        id = cls._validate_id(id)
        data = cls._validate_data(data, allow_missing=True)

        sql = f"UPDATE {cls._table_name} SET {', '.join([f'{k} = ?' for k in data])} WHERE {cls._where_clause} RETURNING *"
        result = db.execute(
            sql,
            *[data[key] for key in data],
            *[id[key] for key in cls._key_columns],
        )

        if len(result) == 0:
            raise EntityError("Failed to update entity", status_code=404)

        return cls._transform_rows(result)[0]

    @classmethod
    def register_route(cls, routes):
        async def update_handler(request):
            id = cls._extract_key(request)
            data = await request.json()
            return cls.update(id, data)

        cls._register_route(routes, "patch", True, update_handler)


class UpsertEntity(CreateEntity):
    @classmethod
    def upsert(cls, data: dict):
        return cls.create(data, allow_upsert=True)

    @classmethod
    def register_route(cls, routes):
        async def upsert_handler(request):
            data = await request.json()
            return cls.upsert(data)

        cls._register_route(routes, "put", False, upsert_handler)


class DeleteEntity(EntityCommon):
    @classmethod
    def delete(cls, id: list):
        id = cls._validate_id(id)
        db.execute(
            f"DELETE FROM {cls._table_name} WHERE {cls._where_clause}",
            *[id[key] for key in cls._key_columns],
        )

    @classmethod
    def register_route(cls, routes):
        async def delete_handler(request):
            id = cls._extract_key(request)
            cls.delete(id)

        cls._register_route(routes, "delete", True, delete_handler)


class BaseEntity(GetEntity, CreateEntity, UpdateEntity, DeleteEntity, GetEntityById):
    pass


@dataclass
class Column:
    type: Any
    required: bool = False
    key: bool = False
    default: Any = None


def column(type_: Any, required: bool = False, key: bool = False, default: Any = None):
    return Column(type_, required, key, default)


def table(table_name: str):
    def decorator(cls):
        # Store table name
        cls._table_name = table_name

        # Process column definitions
        columns: dict[str, Column] = {}
        for attr_name, attr_value in cls.__dict__.items():
            if isinstance(attr_value, Column):
                columns[attr_name] = attr_value

        # Store columns metadata
        cls._columns = columns
        cls._key_columns = [col for col in columns if columns[col].key]
        cls._column_csv = ", ".join([col for col in columns])
        cls._where_clause = " AND ".join([f"{col} = ?" for col in cls._key_columns])

        # Add initialization
        original_init = cls.__init__

        @wraps(original_init)
        def new_init(self, *args, **kwargs):
            # Initialize columns with default values
            for col_name, col_def in cls._columns.items():
                setattr(self, col_name, col_def.default)
            # Call original init
            original_init(self, *args, **kwargs)

        cls.__init__ = new_init
        return cls

    return decorator


def test():
    @table("models")
    class Model(BaseEntity):
        id: int = column(int, required=True, key=True)
        path: str = column(str, required=True)
        name: str = column(str, required=True)
        description: Optional[str] = column(str)
        architecture: Optional[str] = column(str)
        type: str = column(str, required=True)
        hash: Optional[str] = column(str)
        source_url: Optional[str] = column(str)

    return Model


@table("test")
class Test(GetEntity, CreateEntity):
    id: int = column(int, required=True, key=True)
    test: str = column(str, required=True)


Model = test()
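A sketch of how these pieces compose, mirroring the pattern used in the unit tests below; the notes table is hypothetical and not part of this changeset:

    from typing import Optional
    from aiohttp import web
    from app.database.entities import GetEntity, column, table

    @table("notes")
    class Note(GetEntity):
        id: int = column(int, required=True, key=True)
        text: str = column(str, required=True)
        author: Optional[str] = column(str)

    routes = web.RouteTableDef()
    Note.register_route(routes)  # exposes GET /db/notes (with optional ?top=N)
    app = web.Application()
    app.add_routes(routes)

The other mixins (CreateEntity, UpsertEntity, UpdateEntity, DeleteEntity, GetEntityById) register POST, PUT, PATCH, DELETE and keyed GET handlers in the same way.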
32  app/database/routes.py    Normal file
@@ -0,0 +1,32 @@
from app.database.db import db
from aiohttp import web


def create_routes(
    routes, prefix, entity, get=False, get_by_id=False, post=False, delete=False
):
    # `entity` is the table name the routes are generated for.
    if get:

        @routes.get(f"/{prefix}/{entity}")
        async def get_table(request):
            with db.get_connection() as connection:
                cursor = connection.cursor()
                cursor.execute(f"SELECT * FROM {entity}")
                rows = cursor.fetchall()
            return web.json_response(rows)

    if get_by_id:

        @routes.get(f"/{prefix}/{entity}/{{id}}")
        async def get_table_by_id(request):
            row_id = request.match_info["id"]
            with db.get_connection() as connection:
                cursor = connection.cursor()
                cursor.execute(f"SELECT * FROM {entity} WHERE id = ?", (row_id,))
                row = cursor.fetchone()
            return web.json_response(row)

    if post:

        @routes.post(f"/{prefix}/{entity}")
        async def post_table(request):
            data = await request.json()
            columns = ", ".join(data.keys())
            placeholders = ", ".join(["?"] * len(data))
            with db.get_connection() as connection:
                cursor = connection.cursor()
                cursor.execute(
                    f"INSERT INTO {entity} ({columns}) VALUES ({placeholders})",
                    tuple(data.values()),
                )
                connection.commit()
            return web.json_response({"status": "success"})
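A minimal invocation sketch of create_routes; the "api" prefix and "models" table name are illustrative only:

    from aiohttp import web
    from app.database.routes import create_routes

    routes = web.RouteTableDef()
    create_routes(routes, "api", "models", get=True, get_by_id=True)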
79  app/database/updater.py    Normal file
@@ -0,0 +1,79 @@
import logging
import os
import sqlite3
from app.database.versions.v1 import v1


class DatabaseUpdater:
    def __init__(self, connection, database_path):
        self.connection = connection
        self.database_path = database_path
        self.current_version = self.get_db_version()
        self.version_updates = {
            1: v1,
        }
        self.max_version = max(self.version_updates.keys())
        self.update_required = self.current_version < self.max_version
        logging.info(f"Database version: {self.current_version}")

    def get_db_version(self):
        return self.connection.execute("PRAGMA user_version").fetchone()[0]

    def backup(self):
        bkp_path = self.database_path + ".bkp"
        if os.path.exists(bkp_path):
            # TODO: auto-rollback failed upgrades
            raise Exception(
                f"Database backup already exists, which indicates that a previous upgrade failed. Please restore this backup before continuing. Backup location: {bkp_path}"
            )

        bkp = sqlite3.connect(bkp_path)
        self.connection.backup(bkp)
        bkp.close()
        logging.info("Database backup taken pre-upgrade.")
        return bkp_path

    def update(self):
        if not self.update_required:
            return None

        bkp_version = self.current_version
        bkp_path = None
        if self.current_version > 0:
            bkp_path = self.backup()

        logging.info(f"Updating database: {self.current_version} -> {self.max_version}")

        dirname = os.path.dirname(__file__)
        cursor = self.connection.cursor()
        for version in range(self.current_version + 1, self.max_version + 1):
            filename = os.path.join(dirname, f"versions/v{version}.sql")
            if not os.path.exists(filename):
                raise Exception(
                    f"Database update script for version {version} not found"
                )

            try:
                with open(filename, "r") as file:
                    sql = file.read()
                    cursor.executescript(sql)
            except Exception as e:
                raise Exception(
                    f"Failed to execute update script for version {version}: {e}"
                )

            method = self.version_updates[version]
            if method is not None:
                method(cursor)

        cursor.execute("PRAGMA user_version = %d" % self.max_version)
        self.connection.commit()
        cursor.close()
        self.current_version = self.get_db_version()

        if bkp_path:
            # Keep a copy of the backup in case something goes wrong and we need to roll back
            os.rename(bkp_path, self.database_path + f".v{bkp_version}.bkp")
        logging.info("Upgrade successful.")

        return (bkp_version, self.current_version)
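As a sketch of how a later schema revision would plug in, assuming the convention above holds: a versions/v2.sql script plus an optional Python hook keyed as 2 in version_updates (both hypothetical):

    # app/database/versions/v2.py (hypothetical)
    def v2(cursor):
        # Data fixups that are easier in Python than in the v2.sql script.
        cursor.execute("UPDATE models SET architecture = NULL WHERE architecture = ''")

    # In DatabaseUpdater.__init__ the map would become: {1: v1, 2: v2}.
    # update() then executes versions/v2.sql, calls v2(cursor), bumps
    # PRAGMA user_version to 2, and renames the pre-upgrade backup to *.v1.bkp.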
17  app/database/versions/v1.py    Normal file
@@ -0,0 +1,17 @@
from folder_paths import folder_names_and_paths, get_filename_list, get_full_path


def v1(cursor):
    print("Updating to v1")
    for folder_name in folder_names_and_paths.keys():
        if folder_name == "custom_nodes":
            continue

        files = get_filename_list(folder_name)
        for file in files:
            file_path = get_full_path(folder_name, file)
            file_without_extension = file.rsplit(".", maxsplit=1)[0]
            cursor.execute(
                "INSERT INTO models (path, name, type) VALUES (?, ?, ?)",
                (file_path, file_without_extension, folder_name),
            )
41  app/database/versions/v1.sql    Normal file
@@ -0,0 +1,41 @@
CREATE TABLE IF NOT EXISTS
  models (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    path TEXT NOT NULL,
    name TEXT NOT NULL,
    description TEXT,
    architecture TEXT,
    type TEXT NOT NULL,
    hash TEXT,
    source_url TEXT
  );

CREATE TABLE IF NOT EXISTS
  tags (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL UNIQUE
  );

CREATE TABLE IF NOT EXISTS
  model_tags (
    model_id INTEGER NOT NULL,
    tag_id INTEGER NOT NULL,
    PRIMARY KEY (model_id, tag_id),
    FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE,
    FOREIGN KEY (tag_id) REFERENCES tags (id) ON DELETE CASCADE
  );

INSERT INTO
  tags (name)
VALUES
  ('character'),
  ('style'),
  ('concept'),
  ('clothing'),
  ('poses'),
  ('background'),
  ('vehicle'),
  ('buildings'),
  ('objects'),
  ('animal'),
  ('action');
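To show how the three tables relate, a hedged query sketch through the db wrapper; the join reflects the schema above, and 'character' is one of the seeded tag names:

    from app.database.db import db

    rows = db.execute(
        "SELECT m.name, m.type FROM models m "
        "JOIN model_tags mt ON mt.model_id = m.id "
        "JOIN tags t ON t.id = mt.tag_id "
        "WHERE t.name = ?",
        "character",
    )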
63  app/model_hasher.py    Normal file
@@ -0,0 +1,63 @@
import hashlib
import logging
import threading
import time
from comfy.cli_args import args


class ModelHasher:
    def __init__(self):
        self._thread = None
        self._lock = threading.Lock()
        self._model_entity = None

    def start(self, model_entity):
        if args.disable_model_hashing:
            return

        self._model_entity = model_entity

        if self._thread is None:
            # Lock to prevent multiple threads from starting
            with self._lock:
                if self._thread is None:
                    self._thread = threading.Thread(target=self._hash_models)
                    self._thread.daemon = True
                    self._thread.start()

    def _get_models(self):
        # Only fetch models that have not been hashed yet
        models = self._model_entity.get(where="hash IS NULL")
        return models

    def _hash_model(self, model_path):
        h = hashlib.sha256()
        b = bytearray(128 * 1024)
        mv = memoryview(b)
        with open(model_path, "rb", buffering=0) as f:
            while n := f.readinto(mv):
                h.update(mv[:n])
        hash = h.hexdigest()
        return hash

    def _hash_models(self):
        while True:
            models = self._get_models()

            if len(models) == 0:
                break

            for model in models:
                time.sleep(0)
                now = time.time()
                logging.info(f"Hashing model {model['path']}")
                hash = self._hash_model(model["path"])
                logging.info(
                    f"Hashed model {model['path']} in {time.time() - now} seconds"
                )
                self._model_entity.update({"id": model["id"]}, {"hash": hash})

        self._thread = None


model_hasher = ModelHasher()
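Usage mirrors what server.py does below: the singleton is handed an entity for the models table and hashes rows whose hash column is still NULL on a background daemon thread. A sketch, with a stand-in Model entity (the real one comes from get_entity("models"), which is not part of this file):

    from app.database.entities import BaseEntity, column, table
    from app.model_hasher import model_hasher

    @table("models")
    class Model(BaseEntity):
        id: int = column(int, required=True, key=True)
        path: str = column(str, required=True)
        name: str = column(str, required=True)
        type: str = column(str, required=True)
        hash: str = column(str)

    model_hasher.start(Model)  # returns immediately when --disable-model-hashing is set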
comfy/cli_args.py

@@ -143,9 +143,13 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
 parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
 parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
 
+parser.add_argument("--memory-database", default=False, action="store_true", help="Use an in-memory database instead of a file-based one.")
+parser.add_argument("--disable-model-hashing", action="store_true", help="Disable model hashing.")
+
 # The default built-in provider hosted under web/
 DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
 
 parser.add_argument(
     "--front-end-version",
     type=str,
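Both new flags are plain store_true switches; assuming the usual main.py entry point, an invocation such as python main.py --memory-database --disable-model-hashing keeps the database in RAM and skips background hashing.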
17  server.py
@@ -34,6 +34,9 @@ from app.model_manager import ModelFileManager
 from app.custom_node_manager import CustomNodeManager
 from typing import Optional
 from api_server.routes.internal.internal_routes import InternalRoutes
+from app.database.entities import get_entity, init_entities
+from app.database.db import db
+from app.model_hasher import model_hasher
 
 class BinaryEventTypes:
     PREVIEW_IMAGE = 1

@@ -682,11 +685,25 @@
         timeout = aiohttp.ClientTimeout(total=None)  # no timeout
         self.client_session = aiohttp.ClientSession(timeout=timeout)
 
+    def init_db(self, routes):
+        init_entities(routes)
+        models = get_entity("models")
+
+        if db.exists:
+            model_hasher.start(models)
+        else:
+            def on_db_update(_, __):
+                model_hasher.start(models)
+
+            db.register_after_update_callback(on_db_update)
+
     def add_routes(self):
         self.user_manager.add_routes(self.routes)
         self.model_file_manager.add_routes(self.routes)
         self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
         self.app.add_subapp('/internal', self.internal_routes.get_app())
+        self.init_db(self.routes)
 
         # Prefix every route with /api for easier matching for delegation.
         # This is very useful for frontend dev server, which need to forward
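Since add_routes re-registers everything under an /api prefix (see the comment above), the generated entity endpoints should be reachable at paths like /api/db/models once a models entity is registered. A hedged client-side sketch, assuming a local server on the default port 8188:

    import asyncio
    import aiohttp

    async def list_models():
        async with aiohttp.ClientSession() as session:
            async with session.get("http://127.0.0.1:8188/api/db/models") as resp:
                return await resp.json()

    print(asyncio.run(list_models()))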
513  tests-unit/app_test/entities_test.py    Normal file
@@ -0,0 +1,513 @@
from comfy.cli_args import args

args.memory_database = True  # force in-memory database for testing

from typing import Callable, Optional
import pytest
import pytest_asyncio
from unittest.mock import patch
from aiohttp import web
from app.database.entities import (
    DeleteEntity,
    column,
    table,
    Column,
    GetEntity,
    GetEntityById,
    CreateEntity,
    UpsertEntity,
    UpdateEntity,
)
from app.database.db import db

pytestmark = pytest.mark.asyncio


def create_table(entity):
    # reset db
    db.close()

    cols: list[Column] = entity._columns
    # Create tables as temporary so when we close the db, the tables are dropped for the next test
    sql = f"CREATE TEMPORARY TABLE {entity._table_name} ( "
    for col_name, col in cols.items():
        type = None
        if col.type == int:
            type = "INTEGER"
        elif col.type == str:
            type = "TEXT"

        sql += f"{col_name} {type}"
        if col.required:
            sql += " NOT NULL"
        sql += ", "

    sql += f"PRIMARY KEY ({', '.join(entity._key_columns)})"
    sql += ")"
    db.execute(sql)


async def wrap_db(method: Callable, expected_sql: str, expected_args: list):
    with patch.object(db, "execute", wraps=db.execute) as mock:
        response = await method()
        assert mock.call_count == 1
        assert mock.call_args[0][0] == expected_sql
        assert mock.call_args[0][1:] == expected_args
        return response


@pytest.fixture
def getable_entity():
    @table("getable_entity")
    class GetableEntity(GetEntity):
        id: int = column(int, required=True, key=True)
        test: str = column(str, required=True)
        nullable: Optional[str] = column(str)

    return GetableEntity


@pytest.fixture
def getable_by_id_entity():
    @table("getable_by_id_entity")
    class GetableByIdEntity(GetEntityById):
        id: int = column(int, required=True, key=True)
        test: str = column(str, required=True)

    return GetableByIdEntity


@pytest.fixture
def getable_by_id_composite_entity():
    @table("getable_by_id_composite_entity")
    class GetableByIdCompositeEntity(GetEntityById):
        id1: str = column(str, required=True, key=True)
        id2: int = column(int, required=True, key=True)
        test: str = column(str, required=True)

    return GetableByIdCompositeEntity


@pytest.fixture
def creatable_entity():
    @table("creatable_entity")
    class CreatableEntity(CreateEntity):
        id: int = column(int, required=True, key=True)
        test: str = column(str, required=True)
        reqd: str = column(str, required=True)
        nullable: Optional[str] = column(str)

    return CreatableEntity


@pytest.fixture
def upsertable_entity():
    @table("upsertable_entity")
    class UpsertableEntity(UpsertEntity):
        id: int = column(int, required=True, key=True)
        test: str = column(str, required=True)
        reqd: str = column(str, required=True)
        nullable: Optional[str] = column(str)

    return UpsertableEntity


@pytest.fixture
def updateable_entity():
    @table("updateable_entity")
    class UpdateableEntity(UpdateEntity):
        id: int = column(int, required=True, key=True)
        reqd: str = column(str, required=True)

    return UpdateableEntity


@pytest.fixture
def deletable_entity():
    @table("deletable_entity")
    class DeletableEntity(DeleteEntity):
        id: int = column(int, required=True, key=True)

    return DeletableEntity


@pytest.fixture
def deletable_composite_entity():
    @table("deletable_composite_entity")
    class DeletableCompositeEntity(DeleteEntity):
        id1: str = column(str, required=True, key=True)
        id2: int = column(int, required=True, key=True)

    return DeletableCompositeEntity


@pytest.fixture()
def entity(request):
    value = request.getfixturevalue(request.param)
    create_table(value)
    return value


@pytest_asyncio.fixture
async def client(aiohttp_client, app):
    return await aiohttp_client(app)


@pytest.fixture
def app(entity):
    app = web.Application()
    routes = web.RouteTableDef()
    entity.register_route(routes)
    app.add_routes(routes)
    return app


@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_empty_response(client):
    expected_sql = "SELECT * FROM getable_entity"
    expected_args = ()
    response = await wrap_db(
        lambda: client.get("/db/getable_entity"), expected_sql, expected_args
    )

    assert response.status == 200
    assert await response.json() == []


@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_data(client):
    # seed db
    db.execute(
        "INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2')"
    )

    expected_sql = "SELECT * FROM getable_entity"
    expected_args = ()
    response = await wrap_db(
        lambda: client.get("/db/getable_entity"), expected_sql, expected_args
    )

    assert response.status == 200
    assert await response.json() == [
        {"id": 1, "test": "test1", "nullable": None},
        {"id": 2, "test": "test2", "nullable": "test2"},
    ]


@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_top_parameter(client):
    # seed with 3 rows
    db.execute(
        "INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2'), (3, 'test3', 'test3')"
    )

    expected_sql = "SELECT * FROM getable_entity LIMIT 2"
    expected_args = ()
    response = await wrap_db(
        lambda: client.get("/db/getable_entity?top=2"),
        expected_sql,
        expected_args,
    )

    assert response.status == 200
    assert await response.json() == [
        {"id": 1, "test": "test1", "nullable": None},
        {"id": 2, "test": "test2", "nullable": "test2"},
    ]


@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_invalid_top_parameter(client):
    response = await client.get("/db/getable_entity?top=hello")
    assert response.status == 400
    assert await response.json() == {
        "message": "Invalid top parameter",
        "field": "top",
        "value": "hello",
    }


@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
async def test_get_model_by_id_empty_response(client):
    # seed db
    db.execute("INSERT INTO getable_by_id_entity (id, test) VALUES (1, 'test1')")

    expected_sql = "SELECT * FROM getable_by_id_entity WHERE id = ?"
    expected_args = (1,)
    response = await wrap_db(
        lambda: client.get("/db/getable_by_id_entity/1"),
        expected_sql,
        expected_args,
    )

    assert response.status == 200
    assert await response.json() == [
        {"id": 1, "test": "test1"},
    ]


@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
async def test_get_model_by_id_with_invalid_id(client):
    response = await client.get("/db/getable_by_id_entity/hello")
    assert response.status == 400
    assert await response.json() == {
        "message": "Invalid value",
        "field": "id",
        "value": "hello",
    }


@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
async def test_get_model_by_id_composite(client):
    # seed db
    db.execute(
        "INSERT INTO getable_by_id_composite_entity (id1, id2, test) VALUES ('one', 2, 'test')"
    )

    expected_sql = (
        "SELECT * FROM getable_by_id_composite_entity WHERE id1 = ? AND id2 = ?"
    )
    expected_args = ("one", 2)
    response = await wrap_db(
        lambda: client.get("/db/getable_by_id_composite_entity/one/2"),
        expected_sql,
        expected_args,
    )

    assert response.status == 200
    assert await response.json() == [
        {"id1": "one", "id2": 2, "test": "test"},
    ]


@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
async def test_get_model_by_id_composite_with_invalid_id(client):
    response = await client.get("/db/getable_by_id_composite_entity/hello/hello")
    assert response.status == 400
    assert await response.json() == {
        "message": "Invalid value",
        "field": "id2",
        "value": "hello",
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model(client):
    expected_sql = (
        "INSERT INTO creatable_entity (id, test, reqd) VALUES (?, ?, ?) RETURNING *"
    )
    expected_args = (1, "test1", "reqd1")
    response = await wrap_db(
        lambda: client.post(
            "/db/creatable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
        ),
        expected_sql,
        expected_args,
    )

    assert response.status == 200
    assert await response.json() == {
        "id": 1,
        "test": "test1",
        "reqd": "reqd1",
        "nullable": None,
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_missing_required_field(client):
    response = await client.post(
        "/db/creatable_entity", json={"id": 1, "test": "test1"}
    )

    assert response.status == 400
    assert await response.json() == {
        "message": "Missing field",
        "field": "reqd",
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_missing_key_field(client):
    response = await client.post(
        "/db/creatable_entity",
        json={"test": "test1", "reqd": "reqd1"},  # Missing 'id' which is a key
    )

    assert response.status == 400
    assert await response.json() == {
        "message": "Missing field",
        "field": "id",
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_key_data(client):
    response = await client.post(
        "/db/creatable_entity",
        json={
            "id": "not_an_integer",
            "test": "test1",
            "reqd": "reqd1",
        },  # id should be int
    )

    assert response.status == 400
    assert await response.json() == {
        "message": "Invalid value",
        "field": "id",
        "value": "not_an_integer",
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_data(client):
    response = await client.post(
        "/db/creatable_entity",
        json={"id": "aaa", "test": "123", "reqd": "reqd1"},  # id should be int
    )

    assert response.status == 400
    assert await response.json() == {
        "message": "Invalid value",
        "field": "id",
        "value": "aaa",
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_type(client):
    response = await client.post(
        "/db/creatable_entity",
        json={
            "id": 1,
            "test": ["invalid_array"],
            "reqd": "reqd1",
        },  # test should be string
    )

    assert response.status == 400
    assert await response.json() == {
        "message": "Invalid value",
        "field": "test",
        "value": ["invalid_array"],
    }


@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_name(client):
    response = await client.post(
        "/db/creatable_entity",
        json={"id": 1, "test": "test1", "reqd": "reqd1", "nonexistent_field": "value"},
    )

    assert response.status == 400
    assert await response.json() == {
        "message": "Unknown field",
        "field": "nonexistent_field",
    }


@pytest.mark.parametrize("entity", ["upsertable_entity"], indirect=True)
async def test_upsert_model(client):
    expected_sql = (
        "INSERT INTO upsertable_entity (id, test, reqd) VALUES (?, ?, ?) "
        "ON CONFLICT (id) DO UPDATE SET test = excluded.test, reqd = excluded.reqd "
        "RETURNING *"
    )
    expected_args = (1, "test1", "reqd1")
    response = await wrap_db(
        lambda: client.put(
            "/db/upsertable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
        ),
        expected_sql,
        expected_args,
    )

    assert response.status == 200
    assert await response.json() == {
        "id": 1,
        "test": "test1",
        "reqd": "reqd1",
        "nullable": None,
    }


@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model(client):
    # seed db
    db.execute("INSERT INTO updateable_entity (id, reqd) VALUES (1, 'test1')")

    expected_sql = "UPDATE updateable_entity SET reqd = ? WHERE id = ? RETURNING *"
    expected_args = ("updated_test", 1)
    response = await wrap_db(
        lambda: client.patch("/db/updateable_entity/1", json={"reqd": "updated_test"}),
        expected_sql,
        expected_args,
    )

    assert response.status == 200
    assert await response.json() == {
        "id": 1,
        "reqd": "updated_test",
    }


@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model_reject_null_required_field(client):
    response = await client.patch("/db/updateable_entity/1", json={"reqd": None})

    assert response.status == 400
    assert await response.json() == {
        "message": "Required field",
        "field": "reqd",
    }


@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model_reject_invalid_field(client):
    response = await client.patch("/db/updateable_entity/1", json={"hello": "world"})

    assert response.status == 400
    assert await response.json() == {
        "message": "Unknown field",
        "field": "hello",
    }


@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model_reject_missing_record(client):
    response = await client.patch(
        "/db/updateable_entity/1", json={"reqd": "updated_test"}
    )

    assert response.status == 404
    assert await response.json() == {
        "message": "Failed to update entity",
    }


@pytest.mark.parametrize("entity", ["deletable_entity"], indirect=True)
async def test_delete_model(client):
    expected_sql = "DELETE FROM deletable_entity WHERE id = ?"
    expected_args = (1,)
    response = await wrap_db(
        lambda: client.delete("/db/deletable_entity/1"),
        expected_sql,
        expected_args,
    )

    assert response.status == 204


@pytest.mark.parametrize("entity", ["deletable_composite_entity"], indirect=True)
async def test_delete_model_composite_key(client):
    expected_sql = "DELETE FROM deletable_composite_entity WHERE id1 = ? AND id2 = ?"
    expected_args = ("one", 2)
    response = await wrap_db(
        lambda: client.delete("/db/deletable_composite_entity/one/2"),
        expected_sql,
        expected_args,
    )

    assert response.status == 204
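These tests exercise the generated routes end to end against the in-memory database; they should be runnable on their own, e.g. pytest tests-unit/app_test/entities_test.py from the repository root, assuming pytest, pytest-asyncio and an aiohttp pytest plugin providing the aiohttp_client fixture are installed.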