1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

refactor, adding tests

This commit is contained in:
pythongosssss
2025-02-16 17:22:48 +00:00
parent b6b475191d
commit 785a220757
11 changed files with 916 additions and 13 deletions

0
app/database/__init__.py Normal file
View File

View File

@@ -1,9 +1,10 @@
import logging
import os
import sqlite3
from contextlib import contextmanager
from queue import Queue, Empty, Full
import threading
from app.database_updater import DatabaseUpdater
from app.database.updater import DatabaseUpdater
import folder_paths
from comfy.cli_args import args
@@ -11,7 +12,10 @@ from comfy.cli_args import args
class Database:
def __init__(self, database_path=None, pool_size=1):
if database_path is None:
self.exists = False
database_path = "file::memory:?cache=shared"
else:
self.exists = os.path.exists(database_path)
self.database_path = database_path
self.pool_size = pool_size
@@ -20,6 +24,7 @@ class Database:
self._db_lock = threading.Lock()
self._initialized = False
self._closing = False
self._after_update_callbacks = []
def _setup(self):
if self._initialized:
@@ -33,14 +38,28 @@ class Database:
def _create_connection(self):
# TODO: Catch error for sqlite lib missing on linux
logging.info(f"Creating connection to {self.database_path}")
conn = sqlite3.connect(self.database_path, check_same_thread=False)
conn = sqlite3.connect(
self.database_path,
check_same_thread=False,
uri=self.database_path.startswith("file::"),
)
conn.execute("PRAGMA foreign_keys = ON")
self.exists = True
logging.info(f"Connected!")
return conn
def _make_db(self):
with self._get_connection() as con:
updater = DatabaseUpdater(con)
updater.update()
updater = DatabaseUpdater(con, self.database_path)
result = updater.update()
if result is not None:
old_version, new_version = result
for callback in self._after_update_callbacks:
callback(old_version, new_version)
def _transform(self, row, columns):
return {col.name: value for value, col in zip(row, columns)}
@contextmanager
def _get_connection(self):
@@ -71,6 +90,15 @@ class Database:
with self._get_connection() as connection:
yield connection
def execute(self, sql, *args):
with self.get_connection() as connection:
cursor = connection.execute(sql, args)
results = cursor.fetchall()
return results
def register_after_update_callback(self, callback):
self._after_update_callbacks.append(callback)
def close(self):
if self._closing:
return
@@ -82,6 +110,7 @@ class Database:
conn.close()
except Empty:
break
self._closing = False
def __del__(self):
try:

301
app/database/entities.py Normal file
View File

@@ -0,0 +1,301 @@
from typing import Optional, Any, Callable
from dataclasses import dataclass
from functools import wraps
from aiohttp import web
from app.database.db import db
primitives = (bool, str, int, float, type(None))
def is_primitive(obj):
return isinstance(obj, primitives)
class ValidationError(Exception):
def __init__(self, message: str, field: str = None, value: Any = None):
self.message = message
self.field = field
self.value = value
super().__init__(self.message)
def to_json(self):
result = {"message": self.message}
if self.field is not None:
result["field"] = self.field
if self.value is not None:
result["value"] = self.value
return result
def __str__(self) -> str:
return f"{self.message} {self.field} {self.value}"
class EntityCommon(dict):
@classmethod
def _get_route(cls, include_key: bool):
route = f"/db/{cls.__table_name__}"
if include_key:
route += "".join([f"/{{{k}}}" for k in cls.__key_columns__])
return route
@classmethod
def _register_route(cls, routes, verb: str, include_key: bool, handler: Callable):
route = cls._get_route(include_key)
@getattr(routes, verb)(route)
async def _(request):
try:
data = await handler(request)
return web.json_response(data)
except ValidationError as e:
return web.json_response(e.to_json(), status=400)
@classmethod
def _transform(cls, row: list[Any]):
return {col: value for col, value in zip(cls.__columns__, row)}
@classmethod
def _transform_rows(cls, rows: list[list[Any]]):
return [cls._transform(row) for row in rows]
@classmethod
def _validate(cls, fields: list[str], data: dict, allow_missing: bool = False):
result = {}
if not isinstance(data, dict):
raise ValidationError("Invalid data")
# Ensure all required fields are present
for field in data:
if field not in fields:
raise ValidationError("Unknown field", field)
for key in fields:
col = cls.__columns__[key]
if key not in data:
if col.required and not allow_missing:
raise ValidationError("Missing field", key)
else:
# e.g. for updates, we allow missing fields
continue
elif data[key] is None and col.required:
# Dont allow None for required fields
raise ValidationError("Required field", key)
# Validate data type
value = data[key]
if value is not None and not is_primitive(value):
raise ValidationError("Invalid value", key, value)
try:
type = col.type
if value is not None and not isinstance(value, type):
value = type(value)
result[key] = value
except Exception:
raise ValidationError("Invalid value", key, value)
return result
@classmethod
def _validate_id(cls, id: dict):
return cls._validate(cls.__key_columns__, id)
@classmethod
def _validate_data(cls, data: dict):
return cls._validate(cls.__columns__.keys(), data)
def __setattr__(self, name, value):
if name in self.__columns__:
self[name] = value
super().__setattr__(name, value)
def __getattr__(self, name):
if name in self:
return self[name]
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
class GetEntity(EntityCommon):
@classmethod
def get(cls, top: Optional[int] = None, where: Optional[str] = None):
limit = ""
if top is not None and isinstance(top, int):
limit = f" LIMIT {top}"
result = db.execute(
f"SELECT * FROM {cls.__table_name__}{limit}{f' WHERE {where}' if where else ''}",
)
# Map each row in result to an instance of the class
return cls._transform_rows(result)
@classmethod
def register_route(cls, routes):
async def get_handler(request):
top = request.rel_url.query.get("top", None)
if top is not None:
try:
top = int(top)
except Exception:
raise ValidationError("Invalid top parameter", "top", top)
return cls.get(top)
cls._register_route(routes, "get", False, get_handler)
class GetEntityById(EntityCommon):
@classmethod
def get_by_id(cls, id: dict):
id = cls._validate_id(id)
result = db.execute(
f"SELECT * FROM {cls.__table_name__} WHERE {cls.__where_clause__}",
*[id[key] for key in cls.__key_columns__],
)
return cls._transform_rows(result)
@classmethod
def register_route(cls, routes):
async def get_by_id_handler(request):
id = {key: request.match_info.get(key, None) for key in cls.__key_columns__}
return cls.get_by_id(id)
cls._register_route(routes, "get", True, get_by_id_handler)
class CreateEntity(EntityCommon):
@classmethod
def create(cls, data: dict, allow_upsert: bool = False):
data = cls._validate_data(data)
values = ", ".join(["?"] * len(data))
on_conflict = ""
data_keys = ", ".join(list(data.keys()))
if allow_upsert:
# Remove key columns from data
upsert_keys = [key for key in data if key not in cls.__key_columns__]
set_clause = ", ".join([f"{k} = excluded.{k}" for k in upsert_keys])
on_conflict = f" ON CONFLICT ({', '.join(cls.__key_columns__)}) DO UPDATE SET {set_clause}"
sql = f"INSERT INTO {cls.__table_name__} ({data_keys}) VALUES ({values}){on_conflict} RETURNING *"
result = db.execute(
sql,
*[data[key] for key in data],
)
if len(result) == 0:
raise RuntimeError("Failed to create entity")
return cls._transform_rows(result)[0]
@classmethod
def register_route(cls, routes):
async def create_handler(request):
data = await request.json()
return cls.create(data)
cls._register_route(routes, "post", False, create_handler)
class UpdateEntity(EntityCommon):
@classmethod
def update(cls, id: list, data: dict):
pass
class UpsertEntity(CreateEntity):
@classmethod
def upsert(cls, data: dict):
return cls.create(data, allow_upsert=True)
@classmethod
def register_route(cls, routes):
async def upsert_handler(request):
data = await request.json()
return cls.upsert(data)
cls._register_route(routes, "put", False, upsert_handler)
class DeleteEntity(EntityCommon):
@classmethod
def delete(cls, id: list):
pass
class BaseEntity(GetEntity, CreateEntity, UpdateEntity, DeleteEntity, GetEntityById):
pass
@dataclass
class Column:
type: Any
required: bool = False
key: bool = False
default: Any = None
def column(type_: Any, required: bool = False, key: bool = False, default: Any = None):
return Column(type_, required, key, default)
def table(table_name: str):
def decorator(cls):
# Store table name
cls.__table_name__ = table_name
# Process column definitions
columns: dict[str, Column] = {}
for attr_name, attr_value in cls.__dict__.items():
if isinstance(attr_value, Column):
columns[attr_name] = attr_value
# Store columns metadata
cls.__columns__ = columns
cls.__key_columns__ = [col for col in columns if columns[col].key]
cls.__column_csv__ = ", ".join([col for col in columns])
cls.__where_clause__ = " AND ".join(
[f"{col} = ?" for col in cls.__key_columns__]
)
# Add initialization
original_init = cls.__init__
@wraps(original_init)
def new_init(self, *args, **kwargs):
# Initialize columns with default values
for col_name, col_def in cls.__columns__.items():
setattr(self, col_name, col_def.default)
# Call original init
original_init(self, *args, **kwargs)
cls.__init__ = new_init
return cls
return decorator
def test():
@table("models")
class Model(BaseEntity):
id: int = column(int, required=True, key=True)
path: str = column(str, required=True)
name: str = column(str, required=True)
description: Optional[str] = column(str)
architecture: Optional[str] = column(str)
type: str = column(str, required=True)
hash: Optional[str] = column(str)
source_url: Optional[str] = column(str)
return Model
@table("test")
class Test(GetEntity, CreateEntity):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
Model = test()

32
app/database/routes.py Normal file
View File

@@ -0,0 +1,32 @@
from app.database.db import db
from aiohttp import web
def create_routes(
routes, prefix, entity, get=False, get_by_id=False, post=False, delete=False
):
if get:
@routes.get(f"/{prefix}/{table}")
async def get_table(request):
connection = db.get_connection()
cursor = connection.cursor()
cursor.execute(f"SELECT * FROM {table}")
rows = cursor.fetchall()
return web.json_response(rows)
if get_by_id:
@routes.get(f"/{prefix}/{table}/{id}")
async def get_table_by_id(request):
connection = db.get_connection()
cursor = connection.cursor()
cursor.execute(f"SELECT * FROM {table} WHERE id = {id}")
row = cursor.fetchone()
return web.json_response(row)
if post:
@routes.post(f"/{prefix}/{table}")
async def post_table(request):
data = await request.json()
connection = db.get_connection()
cursor = connection.cursor()
cursor.execute(f"INSERT INTO {table} ({data}) VALUES ({data})")
return web.json_response({"status": "success"})

View File

@@ -1,13 +1,16 @@
import logging
import os
import sqlite3
from app.database.versions.v1 import v1
class DatabaseUpdater:
def __init__(self, connection):
def __init__(self, connection, database_path):
self.connection = connection
self.database_path = database_path
self.current_version = self.get_db_version()
self.version_updates = {
1: self._update_to_v1,
1: v1,
}
self.max_version = max(self.version_updates.keys())
self.update_required = self.current_version < self.max_version
@@ -16,16 +19,35 @@ class DatabaseUpdater:
def get_db_version(self):
return self.connection.execute("PRAGMA user_version").fetchone()[0]
def backup(self):
bkp_path = self.database_path + ".bkp"
if os.path.exists(bkp_path):
# TODO: auto-rollback failed upgrades
raise Exception(
f"Database backup already exists, this indicates that a previous upgrade failed. Please restore this backup before continuing. Backup location: {bkp_path}"
)
bkp = sqlite3.connect(bkp_path)
self.connection.backup(bkp)
bkp.close()
logging.info("Database backup taken pre-upgrade.")
return bkp_path
def update(self):
if not self.update_required:
return
return None
bkp_version = self.current_version
bkp_path = None
if self.current_version > 0:
bkp_path = self.backup()
logging.info(f"Updating database: {self.current_version} -> {self.max_version}")
dirname = os.path.dirname(__file__)
cursor = self.connection.cursor()
for version in range(self.current_version + 1, self.max_version + 1):
filename = os.path.join(dirname, f"db/v{version}.sql")
filename = os.path.join(dirname, f"versions/v{version}.sql")
if not os.path.exists(filename):
raise Exception(
f"Database update script for version {version} not found"
@@ -49,6 +71,9 @@ class DatabaseUpdater:
cursor.close()
self.current_version = self.get_db_version()
def _update_to_v1(self, cursor):
# TODO: migrate users and settings
print("Updating to v1")
if bkp_path:
# Keep a copy of the backup in case something goes wrong and we need to rollback
os.rename(bkp_path, self.database_path + f".v{bkp_version}.bkp")
logging.info(f"Upgrade to successful.")
return (bkp_version, self.current_version)

View File

@@ -0,0 +1,17 @@
from folder_paths import folder_names_and_paths, get_filename_list, get_full_path
def v1(cursor):
print("Updating to v1")
for folder_name in folder_names_and_paths.keys():
if folder_name == "custom_nodes":
continue
files = get_filename_list(folder_name)
for file in files:
file_path = get_full_path(folder_name, file)
file_without_extension = file.rsplit(".", maxsplit=1)[0]
cursor.execute(
"INSERT INTO models (path, name, type) VALUES (?, ?, ?)",
(file_path, file_without_extension, folder_name),
)

View File

@@ -3,9 +3,10 @@ CREATE TABLE IF NOT EXISTS
id INTEGER PRIMARY KEY AUTOINCREMENT,
path TEXT NOT NULL,
name TEXT NOT NULL,
model TEXT NOT NULL,
description TEXT,
architecture TEXT,
type TEXT NOT NULL,
hash TEXT NOT NULL,
hash TEXT,
source_url TEXT
);
@@ -23,3 +24,18 @@ CREATE TABLE IF NOT EXISTS
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE,
FOREIGN KEY (tag_id) REFERENCES tags (id) ON DELETE CASCADE
);
INSERT INTO
tags (name)
VALUES
('character'),
('style'),
('concept'),
('clothing'),
('poses'),
('background'),
('vehicle'),
('buildings'),
('objects'),
('animal'),
('action');

63
app/model_hasher.py Normal file
View File

@@ -0,0 +1,63 @@
import hashlib
import logging
import threading
import time
from comfy.cli_args import args
class ModelHasher:
def __init__(self):
self._thread = None
self._lock = threading.Lock()
self._model_entity = None
def start(self, model_entity):
if args.disable_model_hashing:
return
self._model_entity = model_entity
if self._thread is None:
# Lock to prevent multiple threads from starting
with self._lock:
if self._thread is None:
self._thread = threading.Thread(target=self._hash_models)
self._thread.daemon = True
self._thread.start()
def _get_models(self):
models = self._model_entity.get("WHERE hash IS NULL")
return models
def _hash_model(self, model_path):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(model_path, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
hash = h.hexdigest()
return hash
def _hash_models(self):
while True:
models = self._get_models()
if len(models) == 0:
break
for model in models:
time.sleep(0)
now = time.time()
logging.info(f"Hashing model {model['path']}")
hash = self._hash_model(model["path"])
logging.info(
f"Hashed model {model['path']} in {time.time() - now} seconds"
)
self._model_entity.update((model["id"],), {"hash": hash})
self._thread = None
model_hasher = ModelHasher()

View File

@@ -144,10 +144,12 @@ parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choic
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
parser.add_argument("--memory-database", default=False, action="store_true", help="Use an in-memory database instead of a file-based one.")
parser.add_argument("--disable-model-hashing", action="store_true", help="Disable model hashing.")
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
parser.add_argument(
"--front-end-version",
type=str,

View File

@@ -34,6 +34,9 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes
from app.database.entities import get_entity, init_entities
from app.database.db import db
from app.model_hasher import model_hasher
class BinaryEventTypes:
PREVIEW_IMAGE = 1
@@ -682,11 +685,25 @@ class PromptServer():
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)
def init_db(self, routes):
init_entities(routes)
models = get_entity("models")
if db.exists:
model_hasher.start(models)
else:
def on_db_update(_, __):
model_hasher.start(models)
db.register_after_update_callback(on_db_update)
def add_routes(self):
self.user_manager.add_routes(self.routes)
self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.app.add_subapp('/internal', self.internal_routes.get_app())
self.init_db(self.routes)
# Prefix every route with /api for easier matching for delegation.
# This is very useful for frontend dev server, which need to forward

View File

@@ -0,0 +1,401 @@
from comfy.cli_args import args
args.memory_database = True # force in-memory database for testing
from typing import Callable, Optional
import pytest
import pytest_asyncio
from unittest.mock import patch
from aiohttp import web
from app.database.entities import (
column,
table,
Column,
GetEntity,
GetEntityById,
CreateEntity,
UpsertEntity,
)
from app.database.db import db
pytestmark = pytest.mark.asyncio
def create_table(entity):
# reset db
db.close()
cols: list[Column] = entity.__columns__
# Create tables as temporary so when we close the db, the tables are dropped for next test
sql = f"CREATE TEMPORARY TABLE {entity.__table_name__} ( "
for col_name, col in cols.items():
type = None
if col.type == int:
type = "INTEGER"
elif col.type == str:
type = "TEXT"
sql += f"{col_name} {type}"
if col.required:
sql += " NOT NULL"
sql += ", "
sql += f"PRIMARY KEY ({', '.join(entity.__key_columns__)})"
sql += ")"
db.execute(sql)
async def wrap_db(method: Callable, expected_sql: str, expected_args: list):
with patch.object(db, "execute", wraps=db.execute) as mock:
response = await method()
assert mock.call_args[0][0] == expected_sql
assert mock.call_args[0][1:] == expected_args
return response
@pytest.fixture
def getable_entity():
@table("getable_entity")
class GetableEntity(GetEntity):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
nullable: Optional[str] = column(str)
return GetableEntity
@pytest.fixture
def getable_by_id_entity():
@table("getable_by_id_entity")
class GetableByIdEntity(GetEntityById):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
return GetableByIdEntity
@pytest.fixture
def getable_by_id_composite_entity():
@table("getable_by_id_composite_entity")
class GetableByIdCompositeEntity(GetEntityById):
id1: str = column(str, required=True, key=True)
id2: int = column(int, required=True, key=True)
test: str = column(str, required=True)
return GetableByIdCompositeEntity
@pytest.fixture
def creatable_entity():
@table("creatable_entity")
class CreatableEntity(CreateEntity):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
reqd: str = column(str, required=True)
nullable: Optional[str] = column(str)
return CreatableEntity
@pytest.fixture
def upsertable_entity():
@table("upsertable_entity")
class UpsertableEntity(UpsertEntity):
id: int = column(int, required=True, key=True)
test: str = column(str, required=True)
reqd: str = column(str, required=True)
nullable: Optional[str] = column(str)
return UpsertableEntity
@pytest.fixture()
def entity(request):
value = request.getfixturevalue(request.param)
create_table(value)
return value
@pytest_asyncio.fixture
async def client(aiohttp_client, app):
return await aiohttp_client(app)
@pytest.fixture
def app(entity):
app = web.Application()
routes = web.RouteTableDef()
entity.register_route(routes)
app.add_routes(routes)
return app
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_empty_response(client):
expected_sql = "SELECT * FROM getable_entity"
expected_args = ()
response = await wrap_db(
lambda: client.get("/db/getable_entity"), expected_sql, expected_args
)
assert response.status == 200
assert await response.json() == []
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_data(client):
# seed db
db.execute(
"INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2')"
)
expected_sql = "SELECT * FROM getable_entity"
expected_args = ()
response = await wrap_db(
lambda: client.get("/db/getable_entity"), expected_sql, expected_args
)
assert response.status == 200
assert await response.json() == [
{"id": 1, "test": "test1", "nullable": None},
{"id": 2, "test": "test2", "nullable": "test2"},
]
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_top_parameter(client):
# seed with 3 rows
db.execute(
"INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2'), (3, 'test3', 'test3')"
)
expected_sql = "SELECT * FROM getable_entity LIMIT 2"
expected_args = ()
response = await wrap_db(
lambda: client.get("/db/getable_entity?top=2"),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == [
{"id": 1, "test": "test1", "nullable": None},
{"id": 2, "test": "test2", "nullable": "test2"},
]
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_invalid_top_parameter(client):
response = await client.get("/db/getable_entity?top=hello")
assert response.status == 400
assert await response.json() == {
"message": "Invalid top parameter",
"field": "top",
"value": "hello",
}
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
async def test_get_model_by_id_empty_response(client):
# seed db
db.execute("INSERT INTO getable_by_id_entity (id, test) VALUES (1, 'test1')")
expected_sql = "SELECT * FROM getable_by_id_entity WHERE id = ?"
expected_args = (1,)
response = await wrap_db(
lambda: client.get("/db/getable_by_id_entity/1"),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == [
{"id": 1, "test": "test1"},
]
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
async def test_get_model_by_id_with_invalid_id(client):
response = await client.get("/db/getable_by_id_entity/hello")
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "id",
"value": "hello",
}
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
async def test_get_model_by_id_composite(client):
# seed db
db.execute(
"INSERT INTO getable_by_id_composite_entity (id1, id2, test) VALUES ('one', 2, 'test')"
)
expected_sql = (
"SELECT * FROM getable_by_id_composite_entity WHERE id1 = ? AND id2 = ?"
)
expected_args = ("one", 2)
response = await wrap_db(
lambda: client.get("/db/getable_by_id_composite_entity/one/2"),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == [
{"id1": "one", "id2": 2, "test": "test"},
]
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
async def test_get_model_by_id_composite_with_invalid_id(client):
response = await client.get("/db/getable_by_id_composite_entity/hello/hello")
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "id2",
"value": "hello",
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model(client):
expected_sql = (
"INSERT INTO creatable_entity (id, test, reqd) VALUES (?, ?, ?) RETURNING *"
)
expected_args = (1, "test1", "reqd1")
response = await wrap_db(
lambda: client.post(
"/db/creatable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == {
"id": 1,
"test": "test1",
"reqd": "reqd1",
"nullable": None,
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_missing_required_field(client):
response = await client.post(
"/db/creatable_entity", json={"id": 1, "test": "test1"}
)
assert response.status == 400
assert await response.json() == {
"message": "Missing field",
"field": "reqd",
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_missing_key_field(client):
response = await client.post(
"/db/creatable_entity",
json={"test": "test1", "reqd": "reqd1"}, # Missing 'id' which is a key
)
assert response.status == 400
assert await response.json() == {
"message": "Missing field",
"field": "id",
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_key_data(client):
response = await client.post(
"/db/creatable_entity",
json={
"id": "not_an_integer",
"test": "test1",
"reqd": "reqd1",
}, # id should be int
)
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "id",
"value": "not_an_integer",
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_data(client):
response = await client.post(
"/db/creatable_entity",
json={"id": "aaa", "test": "123", "reqd": "reqd1"}, # id should be int
)
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "id",
"value": "aaa",
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_type(client):
response = await client.post(
"/db/creatable_entity",
json={
"id": 1,
"test": ["invalid_array"],
"reqd": "reqd1",
}, # test should be string
)
assert response.status == 400
assert await response.json() == {
"message": "Invalid value",
"field": "test",
"value": ["invalid_array"],
}
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_name(client):
response = await client.post(
"/db/creatable_entity",
json={"id": 1, "test": "test1", "reqd": "reqd1", "nonexistent_field": "value"},
)
assert response.status == 400
assert await response.json() == {
"message": "Unknown field",
"field": "nonexistent_field",
}
@pytest.mark.parametrize("entity", ["upsertable_entity"], indirect=True)
async def test_upsert_model(client):
expected_sql = (
"INSERT INTO upsertable_entity (id, test, reqd) VALUES (?, ?, ?) "
"ON CONFLICT (id) DO UPDATE SET test = excluded.test, reqd = excluded.reqd "
"RETURNING *"
)
expected_args = (1, "test1", "reqd1")
response = await wrap_db(
lambda: client.put(
"/db/upsertable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
),
expected_sql,
expected_args,
)
assert response.status == 200
assert await response.json() == {
"id": 1,
"test": "test1",
"reqd": "reqd1",
"nullable": None,
}