1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Compare commits

...

3 Commits

Author SHA1 Message Date
pythongosssss
01110de8a3 Add tests for delete & update 2025-02-21 17:54:14 +00:00
pythongosssss
785a220757 refactor, adding tests 2025-02-16 17:22:48 +00:00
pythongosssss
b6b475191d Add sqlite db 2025-01-30 21:48:53 +00:00
11 changed files with 1235 additions and 0 deletions

0
app/database/__init__.py Normal file
View File

126
app/database/db.py Normal file
View File

@@ -0,0 +1,126 @@
import logging
import os
import sqlite3
from contextlib import contextmanager
from queue import Queue, Empty, Full
import threading
from app.database.updater import DatabaseUpdater
import folder_paths
from comfy.cli_args import args
class Database:
def __init__(self, database_path=None, pool_size=1):
if database_path is None:
self.exists = False
database_path = "file::memory:?cache=shared"
else:
self.exists = os.path.exists(database_path)
self.database_path = database_path
self.pool_size = pool_size
# Store connections in a pool, default to 1 as normal usage is going to be from a single thread at a time
self.connection_pool: Queue = Queue(maxsize=pool_size)
self._db_lock = threading.Lock()
self._initialized = False
self._closing = False
self._after_update_callbacks = []
def _setup(self):
if self._initialized:
return
with self._db_lock:
if not self._initialized:
self._make_db()
self._initialized = True
def _create_connection(self):
# TODO: Catch error for sqlite lib missing on linux
logging.info(f"Creating connection to {self.database_path}")
conn = sqlite3.connect(
self.database_path,
check_same_thread=False,
uri=self.database_path.startswith("file::"),
)
conn.execute("PRAGMA foreign_keys = ON")
self.exists = True
logging.info(f"Connected!")
return conn
def _make_db(self):
with self._get_connection() as con:
updater = DatabaseUpdater(con, self.database_path)
result = updater.update()
if result is not None:
old_version, new_version = result
for callback in self._after_update_callbacks:
callback(old_version, new_version)
def _transform(self, row, columns):
return {col.name: value for value, col in zip(row, columns)}
@contextmanager
def _get_connection(self):
if self._closing:
raise Exception("Database is shutting down")
try:
# Try to get connection from pool
connection = self.connection_pool.get_nowait()
except Empty:
# Create new connection if pool is empty
connection = self._create_connection()
try:
yield connection
finally:
try:
# Try to add to pool if it's empty
self.connection_pool.put_nowait(connection)
except Full:
# Pool is full, close the connection
connection.close()
@contextmanager
def get_connection(self):
# Setup the database if it's not already initialized
self._setup()
with self._get_connection() as connection:
yield connection
def execute(self, sql, *args):
with self.get_connection() as connection:
cursor = connection.execute(sql, args)
results = cursor.fetchall()
return results
def register_after_update_callback(self, callback):
self._after_update_callbacks.append(callback)
def close(self):
if self._closing:
return
# Drain and close all connections in the pool
self._closing = True
while True:
try:
conn = self.connection_pool.get_nowait()
conn.close()
except Empty:
break
self._closing = False
def __del__(self):
try:
self.close()
except:
pass
# Create a global instance
db_path = None
if not args.memory_database:
db_path = folder_paths.get_user_directory() + "/comfyui.db"
db = Database(db_path)

343
app/database/entities.py Normal file
View File

@@ -0,0 +1,343 @@
from typing import Optional, Any, Callable
from dataclasses import dataclass
from functools import wraps
from aiohttp import web
from app.database.db import db
primitives = (bool, str, int, float, type(None))
def is_primitive(obj):
return isinstance(obj, primitives)
class EntityError(Exception):
def __init__(
self, message: str, field: str = None, value: Any = None, status_code: int = 400
):
self.message = message
self.field = field
self.value = value
self.status_code = status_code
super().__init__(self.message)
def to_json(self):
result = {"message": self.message}
if self.field is not None:
result["field"] = self.field
if self.value is not None:
result["value"] = self.value
return result
def __str__(self) -> str:
return f"{self.message} {self.field} {self.value}"
class EntityCommon(dict):
@classmethod
def _get_route(cls, include_key: bool):
route = f"/db/{cls._table_name}"
if include_key:
route += "".join([f"/{{{k}}}" for k in cls._key_columns])
return route
@classmethod
def _register_route(cls, routes, verb: str, include_key: bool, handler: Callable):
route = cls._get_route(include_key)
@getattr(routes, verb)(route)
async def _(request):
try:
data = await handler(request)
if data is None:
return web.json_response(status=204)
return web.json_response(data)
except EntityError as e:
return web.json_response(e.to_json(), status=e.status_code)
@classmethod
def _transform(cls, row: list[Any]):
return {col: value for col, value in zip(cls._columns, row)}
@classmethod
def _transform_rows(cls, rows: list[list[Any]]):
return [cls._transform(row) for row in rows]
@classmethod
def _extract_key(cls, request):
return {key: request.match_info.get(key, None) for key in cls._key_columns}
@classmethod
def _validate(cls, fields: list[str], data: dict, allow_missing: bool = False):
result = {}
if not isinstance(data, dict):
raise EntityError("Invalid data")
# Ensure all required fields are present
for field in data:
if field not in fields:
raise EntityError("Unknown field", field)
for key in fields:
col = cls._columns[key]
if key not in data:
if col.required and not allow_missing:
raise EntityError("Missing field", key)
else:
# e.g. for updates, we allow missing fields
continue
elif data[key] is None and col.required:
# Dont allow None for required fields
raise EntityError("Required field", key)
# Validate data type
value = data[key]
if value is not None and not is_primitive(value):
raise EntityError("Invalid value", key, value)
try:
type = col.type
if value is not None and not isinstance(value, type):
value = type(value)
result[key] = value
except Exception:
raise EntityError("Invalid value", key, value)
return result
@classmethod
def _validate_id(cls, id: dict):
return cls._validate(cls._key_columns, id)
@classmethod
def _validate_data(cls, data: dict, allow_missing: bool = False):
return cls._validate(cls._columns.keys(), data, allow_missing)
def __setattr__(self, name, value):
if name in self._columns:
self[name] = value
super().__setattr__(name, value)
def __getattr__(self, name):
if name in self:
return self[name]
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
class GetEntity(EntityCommon):
@classmethod
def get(cls, top: Optional[int] = None, where: Optional[str] = None):
limit = ""
if top is not None and isinstance(top, int):
limit = f" LIMIT {top}"
result = db.execute(
f"SELECT * FROM {cls._table_name}{limit}{f' WHERE {where}' if where else ''}",
)
# Map each row in result to an instance of the class
return cls._transform_rows(result)
@classmethod
def register_route(cls, routes):
async def get_handler(request):
top = request.rel_url.query.get("top", None)
if top is not None:
try:
top = int(top)
except Exception:
raise EntityError("Invalid top parameter", "top", top)
return cls.get(top)
cls._register_route(routes, "get", False, get_handler)
class GetEntityById(EntityCommon):
@classmethod
def get_by_id(cls, id: dict):
id = cls._validate_id(id)
result = db.execute(
f"SELECT * FROM {cls._table_name} WHERE {cls._where_clause}",
*[id[key] for key in cls._key_columns],
)
return cls._transform_rows(result)
@classmethod
def register_route(cls, routes):
async def get_by_id_handler(request):
id = cls._extract_key(request)
return cls.get_by_id(id)
cls._register_route(routes, "get", True, get_by_id_handler)
class CreateEntity(EntityCommon):
@classmethod
def create(cls, data: dict, allow_upsert: bool = False):
data = cls._validate_data(data)
values = ", ".join(["?"] * len(data))
on_conflict = ""
data_keys = ", ".join(list(data.keys()))
if allow_upsert:
# Remove key columns from data
upsert_keys = [key for key in data if key not in cls._key_columns]
set_clause = ", ".join([f"{k} = excluded.{k}" for k in upsert_keys])
on_conflict = f" ON CONFLICT ({', '.join(cls._key_columns)}) DO UPDATE SET {set_clause}"
sql = f"INSERT INTO {cls._table_name} ({data_keys}) VALUES ({values}){on_conflict} RETURNING *"
result = db.execute(
sql,
*[data[key] for key in data],
)
if len(result) == 0:
raise EntityError("Failed to create entity", status_code=500)
return cls._transform_rows(result)[0]
@classmethod
def register_route(cls, routes):
async def create_handler(request):
data = await request.json()
return cls.create(data)
cls._register_route(routes, "post", False, create_handler)
class UpdateEntity(EntityCommon):
@classmethod
def update(cls, id: list, data: dict):
id = cls._validate_id(id)
data = cls._validate_data(data, allow_missing=True)
sql = f"UPDATE {cls._table_name} SET {', '.join([f'{k} = ?' for k in data])} WHERE {cls._where_clause} RETURNING *"
result = db.execute(
sql,
*[data[key] for key in data],
*[id[key] for key in cls._key_columns],
)
if len(result) == 0:
raise EntityError("Failed to update entity", status_code=404)
return cls._transform_rows(result)[0]
@classmethod
def register_route(cls, routes):
async def update_handler(request):
id = cls._extract_key(request)
data = await request.json()
return cls.update(id, data)
cls._register_route(routes, "patch", True, update_handler)
class UpsertEntity(CreateEntity):
@classmethod
def upsert(cls, data: dict):
return cls.create(data, allow_upsert=True)
@classmethod
def register_route(cls, routes):
async def upsert_handler(request):
data = await request.json()
return cls.upsert(data)
cls._register_route(routes, "put", False, upsert_handler)
class DeleteEntity(EntityCommon):
@classmethod
def delete(cls, id: list):
id = cls._validate_id(id)
db.execute(
f"DELETE FROM {cls._table_name} WHERE {cls._where_clause}",
*[id[key] for key in cls._key_columns],
)
@classmethod
def register_route(cls, routes):
async def delete_handler(request):
id = cls._extract_key(request)
cls.delete(id)
cls._register_route(routes, "delete", True, delete_handler)
class BaseEntity(GetEntity, CreateEntity, UpdateEntity, DeleteEntity, GetEntityById):
pass
@dataclass
class Column:
type: Any
required: bool = False
key: bool = False
default: Any = None
def column(type_: Any, required: bool = False, key: bool = False, default: Any = None):
return Column(type_, required, key, default)
def table(table_name: str):
def decorator(cls):
# Store table name
cls._table_name = table_name
# Process column definitions
columns: dict[str, Column] = {}
for attr_name, attr_value in cls.__dict__.items():
if isinstance(attr_value, Column):
columns[attr_name] = attr_value
# Store columns metadata
cls._columns = columns
cls._key_columns = [col for col in columns if columns[col].key]
cls._column_csv = ", ".join([col for col in columns])
cls._where_clause = " AND ".join([f"{col} = ?" for col in cls._key_columns])
# Add initialization
original_init = cls.__init__
@wraps(original_init)
def new_init(self, *args, **kwargs):
# Initialize columns with default values
for col_name, col_def in cls._columns.items():
setattr(self, col_name, col_def.default)
# Call original init
original_init(self, *args, **kwargs)
cls.__init__ = new_init
return cls
return decorator
def test():
@table("models")
class Model(BaseEntity):
id: int = column(int, required=True, key=True)
path: str = column(str, required=True)
name: str = column(str, required=True)
description: Optional[str] = column(str)
architecture: Optional[str] = column(str)
type: str = column(str, required=True)
hash: Optional[str] = column(str)
source_url: Optional[str] = column(str)
return Model
@table("test")
class Test(GetEntity, CreateEntity):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
Model = test()

32
app/database/routes.py Normal file
View File

@@ -0,0 +1,32 @@
from app.database.db import db
from aiohttp import web
def create_routes(
routes, prefix, entity, get=False, get_by_id=False, post=False, delete=False
):
if get:
@routes.get(f"/{prefix}/{table}")
async def get_table(request):
connection = db.get_connection()
cursor = connection.cursor()
cursor.execute(f"SELECT * FROM {table}")
rows = cursor.fetchall()
return web.json_response(rows)
if get_by_id:
@routes.get(f"/{prefix}/{table}/{id}")
async def get_table_by_id(request):
connection = db.get_connection()
cursor = connection.cursor()
cursor.execute(f"SELECT * FROM {table} WHERE id = {id}")
row = cursor.fetchone()
return web.json_response(row)
if post:
@routes.post(f"/{prefix}/{table}")
async def post_table(request):
data = await request.json()
connection = db.get_connection()
cursor = connection.cursor()
cursor.execute(f"INSERT INTO {table} ({data}) VALUES ({data})")
return web.json_response({"status": "success"})

79
app/database/updater.py Normal file
View File

@@ -0,0 +1,79 @@
import logging
import os
import sqlite3
from app.database.versions.v1 import v1
class DatabaseUpdater:
def __init__(self, connection, database_path):
self.connection = connection
self.database_path = database_path
self.current_version = self.get_db_version()
self.version_updates = {
1: v1,
}
self.max_version = max(self.version_updates.keys())
self.update_required = self.current_version < self.max_version
logging.info(f"Database version: {self.current_version}")
def get_db_version(self):
return self.connection.execute("PRAGMA user_version").fetchone()[0]
def backup(self):
bkp_path = self.database_path + ".bkp"
if os.path.exists(bkp_path):
# TODO: auto-rollback failed upgrades
raise Exception(
f"Database backup already exists, this indicates that a previous upgrade failed. Please restore this backup before continuing. Backup location: {bkp_path}"
)
bkp = sqlite3.connect(bkp_path)
self.connection.backup(bkp)
bkp.close()
logging.info("Database backup taken pre-upgrade.")
return bkp_path
def update(self):
if not self.update_required:
return None
bkp_version = self.current_version
bkp_path = None
if self.current_version > 0:
bkp_path = self.backup()
logging.info(f"Updating database: {self.current_version} -> {self.max_version}")
dirname = os.path.dirname(__file__)
cursor = self.connection.cursor()
for version in range(self.current_version + 1, self.max_version + 1):
filename = os.path.join(dirname, f"versions/v{version}.sql")
if not os.path.exists(filename):
raise Exception(
f"Database update script for version {version} not found"
)
try:
with open(filename, "r") as file:
sql = file.read()
cursor.executescript(sql)
except Exception as e:
raise Exception(
f"Failed to execute update script for version {version}: {e}"
)
method = self.version_updates[version]
if method is not None:
method(cursor)
cursor.execute("PRAGMA user_version = %d" % self.max_version)
self.connection.commit()
cursor.close()
self.current_version = self.get_db_version()
if bkp_path:
# Keep a copy of the backup in case something goes wrong and we need to rollback
os.rename(bkp_path, self.database_path + f".v{bkp_version}.bkp")
logging.info(f"Upgrade to successful.")
return (bkp_version, self.current_version)

View File

@@ -0,0 +1,17 @@
from folder_paths import folder_names_and_paths, get_filename_list, get_full_path
def v1(cursor):
print("Updating to v1")
for folder_name in folder_names_and_paths.keys():
if folder_name == "custom_nodes":
continue
files = get_filename_list(folder_name)
for file in files:
file_path = get_full_path(folder_name, file)
file_without_extension = file.rsplit(".", maxsplit=1)[0]
cursor.execute(
"INSERT INTO models (path, name, type) VALUES (?, ?, ?)",
(file_path, file_without_extension, folder_name),
)

View File

@@ -0,0 +1,41 @@
CREATE TABLE IF NOT EXISTS
models (
id INTEGER PRIMARY KEY AUTOINCREMENT,
path TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
architecture TEXT,
type TEXT NOT NULL,
hash TEXT,
source_url TEXT
);
CREATE TABLE IF NOT EXISTS
tags (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS
model_tags (
model_id INTEGER NOT NULL,
tag_id INTEGER NOT NULL,
PRIMARY KEY (model_id, tag_id),
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE,
FOREIGN KEY (tag_id) REFERENCES tags (id) ON DELETE CASCADE
);
INSERT INTO
tags (name)
VALUES
('character'),
('style'),
('concept'),
('clothing'),
('poses'),
('background'),
('vehicle'),
('buildings'),
('objects'),
('animal'),
('action');

63
app/model_hasher.py Normal file
View File

@@ -0,0 +1,63 @@
import hashlib
import logging
import threading
import time
from comfy.cli_args import args
class ModelHasher:
def __init__(self):
self._thread = None
self._lock = threading.Lock()
self._model_entity = None
def start(self, model_entity):
if args.disable_model_hashing:
return
self._model_entity = model_entity
if self._thread is None:
# Lock to prevent multiple threads from starting
with self._lock:
if self._thread is None:
self._thread = threading.Thread(target=self._hash_models)
self._thread.daemon = True
self._thread.start()
def _get_models(self):
models = self._model_entity.get("WHERE hash IS NULL")
return models
def _hash_model(self, model_path):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(model_path, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
hash = h.hexdigest()
return hash
def _hash_models(self):
while True:
models = self._get_models()
if len(models) == 0:
break
for model in models:
time.sleep(0)
now = time.time()
logging.info(f"Hashing model {model['path']}")
hash = self._hash_model(model["path"])
logging.info(
f"Hashed model {model['path']} in {time.time() - now} seconds"
)
self._model_entity.update((model["id"],), {"hash": hash})
self._thread = None
model_hasher = ModelHasher()

View File

@@ -143,9 +143,13 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
parser.add_argument("--memory-database", default=False, action="store_true", help="Use an in-memory database instead of a file-based one.")
parser.add_argument("--disable-model-hashing", action="store_true", help="Disable model hashing.")
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
parser.add_argument(
"--front-end-version",
type=str,

View File

@@ -34,6 +34,9 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes
from app.database.entities import get_entity, init_entities
from app.database.db import db
from app.model_hasher import model_hasher
class BinaryEventTypes:
PREVIEW_IMAGE = 1
@@ -682,11 +685,25 @@ class PromptServer():
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)
def init_db(self, routes):
init_entities(routes)
models = get_entity("models")
if db.exists:
model_hasher.start(models)
else:
def on_db_update(_, __):
model_hasher.start(models)
db.register_after_update_callback(on_db_update)
def add_routes(self):
self.user_manager.add_routes(self.routes)
self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.app.add_subapp('/internal', self.internal_routes.get_app())
self.init_db(self.routes)
# Prefix every route with /api for easier matching for delegation.
# This is very useful for frontend dev server, which need to forward

View File

@@ -0,0 +1,513 @@
from comfy.cli_args import args
args.memory_database = True # force in-memory database for testing
from typing import Callable, Optional
import pytest
import pytest_asyncio
from unittest.mock import patch
from aiohttp import web
from app.database.entities import (
DeleteEntity,
column,
table,
Column,
GetEntity,
GetEntityById,
CreateEntity,
UpsertEntity,
UpdateEntity,
)
from app.database.db import db
pytestmark = pytest.mark.asyncio
def create_table(entity):
# reset db
db.close()
cols: list[Column] = entity._columns
# Create tables as temporary so when we close the db, the tables are dropped for next test
sql = f"CREATE TEMPORARY TABLE {entity._table_name} ( "
for col_name, col in cols.items():
type = None
if col.type == int:
type = "INTEGER"
elif col.type == str:
type = "TEXT"
sql += f"{col_name} {type}"
if col.required:
sql += " NOT NULL"
sql += ", "
sql += f"PRIMARY KEY ({', '.join(entity._key_columns)})"
sql += ")"
db.execute(sql)
async def wrap_db(method: Callable, expected_sql: str, expected_args: list):
with patch.object(db, "execute", wraps=db.execute) as mock:
response = await method()
assert mock.call_count == 1
assert mock.call_args[0][0] == expected_sql
assert mock.call_args[0][1:] == expected_args
return response
@pytest.fixture
def getable_entity():
@table("getable_entity")
class GetableEntity(GetEntity):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
nullable: Optional[str] = column(str)
return GetableEntity
@pytest.fixture
def getable_by_id_entity():
@table("getable_by_id_entity")
class GetableByIdEntity(GetEntityById):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
return GetableByIdEntity
@pytest.fixture
def getable_by_id_composite_entity():
@table("getable_by_id_composite_entity")
class GetableByIdCompositeEntity(GetEntityById):
id1: str = column(str, required=True, key=True)
id2: int = column(int, required=True, key=True)
test: str = column(str, required=True)
return GetableByIdCompositeEntity
@pytest.fixture
def creatable_entity():
@table("creatable_entity")
class CreatableEntity(CreateEntity):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
reqd: str = column(str, required=True)
nullable: Optional[str] = column(str)
return CreatableEntity
@pytest.fixture
def upsertable_entity():
@table("upsertable_entity")
class UpsertableEntity(UpsertEntity):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
reqd: str = column(str, required=True)
nullable: Optional[str] = column(str)
return UpsertableEntity
@pytest.fixture
def updateable_entity():
@table("updateable_entity")
class UpdateableEntity(UpdateEntity):
id: int = column(int, required=True, key=True)
reqd: str = column(str, required=True)
return UpdateableEntity
@pytest.fixture
def deletable_entity():
@table("deletable_entity")
class DeletableEntity(DeleteEntity):
id: int = column(int, required=True, key=True)
return DeletableEntity
@pytest.fixture
def deletable_composite_entity():
@table("deletable_composite_entity")
class DeletableCompositeEntity(DeleteEntity):
id1: str = column(str, required=True, key=True)
id2: int = column(int, required=True, key=True)
return DeletableCompositeEntity
@pytest.fixture()
def entity(request):
value = request.getfixturevalue(request.param)
create_table(value)
return value
@pytest_asyncio.fixture
async def client(aiohttp_client, app):
return await aiohttp_client(app)
@pytest.fixture
def app(entity):
app = web.Application()
routes = web.RouteTableDef()
entity.register_route(routes)
app.add_routes(routes)
return app
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_empty_response(client):
expected_sql = "SELECT * FROM getable_entity"
expected_args = ()
response = await wrap_db(
lambda: client.get("/db/getable_entity"), expected_sql, expected_args
)
assert response.status == 200
assert await response.json() == []
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_data(client):
# seed db
db.execute(
"INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2')"
)
expected_sql = "SELECT * FROM getable_entity"
expected_args = ()
response = await wrap_db(
lambda: client.get("/db/getable_entity"), expected_sql, expected_args
)
assert response.status == 200
assert await response.json() == [
{"id": 1, "test": "test1", "nullable": None},
{"id": 2, "test": "test2", "nullable": "test2"},
]
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_top_parameter(client):
# seed with 3 rows
db.execute(
"INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2'), (3, 'test3', 'test3')"
)
expected_sql = "SELECT * FROM getable_entity LIMIT 2"
expected_args = ()
response = await wrap_db(
lambda: client.get("/db/getable_entity?top=2"),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == [
{"id": 1, "test": "test1", "nullable": None},
{"id": 2, "test": "test2", "nullable": "test2"},
]
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_invalid_top_parameter(client):
response = await client.get("/db/getable_entity?top=hello")
assert response.status == 400
assert await response.json() == {
"message": "Invalid top parameter",
"field": "top",
"value": "hello",
}
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
async def test_get_model_by_id_empty_response(client):
# seed db
db.execute("INSERT INTO getable_by_id_entity (id, test) VALUES (1, 'test1')")
expected_sql = "SELECT * FROM getable_by_id_entity WHERE id = ?"
expected_args = (1,)
response = await wrap_db(
lambda: client.get("/db/getable_by_id_entity/1"),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == [
{"id": 1, "test": "test1"},
]
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
async def test_get_model_by_id_with_invalid_id(client):
response = await client.get("/db/getable_by_id_entity/hello")
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "id",
"value": "hello",
}
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
async def test_get_model_by_id_composite(client):
# seed db
db.execute(
"INSERT INTO getable_by_id_composite_entity (id1, id2, test) VALUES ('one', 2, 'test')"
)
expected_sql = (
"SELECT * FROM getable_by_id_composite_entity WHERE id1 = ? AND id2 = ?"
)
expected_args = ("one", 2)
response = await wrap_db(
lambda: client.get("/db/getable_by_id_composite_entity/one/2"),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == [
{"id1": "one", "id2": 2, "test": "test"},
]
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
async def test_get_model_by_id_composite_with_invalid_id(client):
response = await client.get("/db/getable_by_id_composite_entity/hello/hello")
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "id2",
"value": "hello",
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model(client):
expected_sql = (
"INSERT INTO creatable_entity (id, test, reqd) VALUES (?, ?, ?) RETURNING *"
)
expected_args = (1, "test1", "reqd1")
response = await wrap_db(
lambda: client.post(
"/db/creatable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == {
"id": 1,
"test": "test1",
"reqd": "reqd1",
"nullable": None,
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_missing_required_field(client):
response = await client.post(
"/db/creatable_entity", json={"id": 1, "test": "test1"}
)
assert response.status == 400
assert await response.json() == {
"message": "Missing field",
"field": "reqd",
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_missing_key_field(client):
response = await client.post(
"/db/creatable_entity",
json={"test": "test1", "reqd": "reqd1"}, # Missing 'id' which is a key
)
assert response.status == 400
assert await response.json() == {
"message": "Missing field",
"field": "id",
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_key_data(client):
response = await client.post(
"/db/creatable_entity",
json={
"id": "not_an_integer",
"test": "test1",
"reqd": "reqd1",
}, # id should be int
)
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "id",
"value": "not_an_integer",
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_data(client):
response = await client.post(
"/db/creatable_entity",
json={"id": "aaa", "test": "123", "reqd": "reqd1"}, # id should be int
)
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "id",
"value": "aaa",
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_type(client):
response = await client.post(
"/db/creatable_entity",
json={
"id": 1,
"test": ["invalid_array"],
"reqd": "reqd1",
}, # test should be string
)
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "test",
"value": ["invalid_array"],
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_name(client):
response = await client.post(
"/db/creatable_entity",
json={"id": 1, "test": "test1", "reqd": "reqd1", "nonexistent_field": "value"},
)
assert response.status == 400
assert await response.json() == {
"message": "Unknown field",
"field": "nonexistent_field",
}
@pytest.mark.parametrize("entity", ["upsertable_entity"], indirect=True)
async def test_upsert_model(client):
expected_sql = (
"INSERT INTO upsertable_entity (id, test, reqd) VALUES (?, ?, ?) "
"ON CONFLICT (id) DO UPDATE SET test = excluded.test, reqd = excluded.reqd "
"RETURNING *"
)
expected_args = (1, "test1", "reqd1")
response = await wrap_db(
lambda: client.put(
"/db/upsertable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == {
"id": 1,
"test": "test1",
"reqd": "reqd1",
"nullable": None,
}
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model(client):
# seed db
db.execute("INSERT INTO updateable_entity (id, reqd) VALUES (1, 'test1')")
expected_sql = "UPDATE updateable_entity SET reqd = ? WHERE id = ? RETURNING *"
expected_args = ("updated_test", 1)
response = await wrap_db(
lambda: client.patch("/db/updateable_entity/1", json={"reqd": "updated_test"}),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == {
"id": 1,
"reqd": "updated_test",
}
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model_reject_null_required_field(client):
response = await client.patch("/db/updateable_entity/1", json={"reqd": None})
assert response.status == 400
assert await response.json() == {
"message": "Required field",
"field": "reqd",
}
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model_reject_invalid_field(client):
response = await client.patch("/db/updateable_entity/1", json={"hello": "world"})
assert response.status == 400
assert await response.json() == {
"message": "Unknown field",
"field": "hello",
}
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
async def test_update_model_reject_missing_record(client):
response = await client.patch(
"/db/updateable_entity/1", json={"reqd": "updated_test"}
)
assert response.status == 404
assert await response.json() == {
"message": "Failed to update entity",
}
@pytest.mark.parametrize("entity", ["deletable_entity"], indirect=True)
async def test_delete_model(client):
expected_sql = "DELETE FROM deletable_entity WHERE id = ?"
expected_args = (1,)
response = await wrap_db(
lambda: client.delete("/db/deletable_entity/1"),
expected_sql,
expected_args,
)
assert response.status == 204
@pytest.mark.parametrize("entity", ["deletable_composite_entity"], indirect=True)
async def test_delete_model_composite_key(client):
expected_sql = "DELETE FROM deletable_composite_entity WHERE id1 = ? AND id2 = ?"
expected_args = ("one", 2)
response = await wrap_db(
lambda: client.delete("/db/deletable_composite_entity/one/2"),
expected_sql,
expected_args,
)
assert response.status == 204