diff --git a/app/database/entities.py b/app/database/entities.py index c608ace33..05aaec84c 100644 --- a/app/database/entities.py +++ b/app/database/entities.py @@ -11,11 +11,14 @@ def is_primitive(obj): return isinstance(obj, primitives) -class ValidationError(Exception): - def __init__(self, message: str, field: str = None, value: Any = None): +class EntityError(Exception): + def __init__( + self, message: str, field: str = None, value: Any = None, status_code: int = 400 + ): self.message = message self.field = field self.value = value + self.status_code = status_code super().__init__(self.message) def to_json(self): @@ -33,9 +36,9 @@ class ValidationError(Exception): class EntityCommon(dict): @classmethod def _get_route(cls, include_key: bool): - route = f"/db/{cls.__table_name__}" + route = f"/db/{cls._table_name}" if include_key: - route += "".join([f"/{{{k}}}" for k in cls.__key_columns__]) + route += "".join([f"/{{{k}}}" for k in cls._key_columns]) return route @classmethod @@ -46,47 +49,54 @@ class EntityCommon(dict): async def _(request): try: data = await handler(request) + if data is None: + return web.json_response(status=204) + return web.json_response(data) - except ValidationError as e: - return web.json_response(e.to_json(), status=400) + except EntityError as e: + return web.json_response(e.to_json(), status=e.status_code) @classmethod def _transform(cls, row: list[Any]): - return {col: value for col, value in zip(cls.__columns__, row)} + return {col: value for col, value in zip(cls._columns, row)} @classmethod def _transform_rows(cls, rows: list[list[Any]]): return [cls._transform(row) for row in rows] + @classmethod + def _extract_key(cls, request): + return {key: request.match_info.get(key, None) for key in cls._key_columns} + @classmethod def _validate(cls, fields: list[str], data: dict, allow_missing: bool = False): result = {} if not isinstance(data, dict): - raise ValidationError("Invalid data") + raise EntityError("Invalid data") # Ensure all required fields are present for field in data: if field not in fields: - raise ValidationError("Unknown field", field) + raise EntityError("Unknown field", field) for key in fields: - col = cls.__columns__[key] + col = cls._columns[key] if key not in data: if col.required and not allow_missing: - raise ValidationError("Missing field", key) + raise EntityError("Missing field", key) else: # e.g. for updates, we allow missing fields continue elif data[key] is None and col.required: # Dont allow None for required fields - raise ValidationError("Required field", key) + raise EntityError("Required field", key) # Validate data type value = data[key] if value is not None and not is_primitive(value): - raise ValidationError("Invalid value", key, value) + raise EntityError("Invalid value", key, value) try: type = col.type @@ -94,20 +104,20 @@ class EntityCommon(dict): value = type(value) result[key] = value except Exception: - raise ValidationError("Invalid value", key, value) + raise EntityError("Invalid value", key, value) return result @classmethod def _validate_id(cls, id: dict): - return cls._validate(cls.__key_columns__, id) + return cls._validate(cls._key_columns, id) @classmethod - def _validate_data(cls, data: dict): - return cls._validate(cls.__columns__.keys(), data) + def _validate_data(cls, data: dict, allow_missing: bool = False): + return cls._validate(cls._columns.keys(), data, allow_missing) def __setattr__(self, name, value): - if name in self.__columns__: + if name in self._columns: self[name] = value super().__setattr__(name, value) @@ -124,7 +134,7 @@ class GetEntity(EntityCommon): if top is not None and isinstance(top, int): limit = f" LIMIT {top}" result = db.execute( - f"SELECT * FROM {cls.__table_name__}{limit}{f' WHERE {where}' if where else ''}", + f"SELECT * FROM {cls._table_name}{limit}{f' WHERE {where}' if where else ''}", ) # Map each row in result to an instance of the class @@ -138,7 +148,7 @@ class GetEntity(EntityCommon): try: top = int(top) except Exception: - raise ValidationError("Invalid top parameter", "top", top) + raise EntityError("Invalid top parameter", "top", top) return cls.get(top) cls._register_route(routes, "get", False, get_handler) @@ -150,8 +160,8 @@ class GetEntityById(EntityCommon): id = cls._validate_id(id) result = db.execute( - f"SELECT * FROM {cls.__table_name__} WHERE {cls.__where_clause__}", - *[id[key] for key in cls.__key_columns__], + f"SELECT * FROM {cls._table_name} WHERE {cls._where_clause}", + *[id[key] for key in cls._key_columns], ) return cls._transform_rows(result) @@ -159,7 +169,7 @@ class GetEntityById(EntityCommon): @classmethod def register_route(cls, routes): async def get_by_id_handler(request): - id = {key: request.match_info.get(key, None) for key in cls.__key_columns__} + id = cls._extract_key(request) return cls.get_by_id(id) cls._register_route(routes, "get", True, get_by_id_handler) @@ -175,18 +185,18 @@ class CreateEntity(EntityCommon): data_keys = ", ".join(list(data.keys())) if allow_upsert: # Remove key columns from data - upsert_keys = [key for key in data if key not in cls.__key_columns__] + upsert_keys = [key for key in data if key not in cls._key_columns] set_clause = ", ".join([f"{k} = excluded.{k}" for k in upsert_keys]) - on_conflict = f" ON CONFLICT ({', '.join(cls.__key_columns__)}) DO UPDATE SET {set_clause}" - sql = f"INSERT INTO {cls.__table_name__} ({data_keys}) VALUES ({values}){on_conflict} RETURNING *" + on_conflict = f" ON CONFLICT ({', '.join(cls._key_columns)}) DO UPDATE SET {set_clause}" + sql = f"INSERT INTO {cls._table_name} ({data_keys}) VALUES ({values}){on_conflict} RETURNING *" result = db.execute( sql, *[data[key] for key in data], ) if len(result) == 0: - raise RuntimeError("Failed to create entity") + raise EntityError("Failed to create entity", status_code=500) return cls._transform_rows(result)[0] @@ -202,7 +212,29 @@ class CreateEntity(EntityCommon): class UpdateEntity(EntityCommon): @classmethod def update(cls, id: list, data: dict): - pass + id = cls._validate_id(id) + data = cls._validate_data(data, allow_missing=True) + + sql = f"UPDATE {cls._table_name} SET {', '.join([f'{k} = ?' for k in data])} WHERE {cls._where_clause} RETURNING *" + result = db.execute( + sql, + *[data[key] for key in data], + *[id[key] for key in cls._key_columns], + ) + + if len(result) == 0: + raise EntityError("Failed to update entity", status_code=404) + + return cls._transform_rows(result)[0] + + @classmethod + def register_route(cls, routes): + async def update_handler(request): + id = cls._extract_key(request) + data = await request.json() + return cls.update(id, data) + + cls._register_route(routes, "patch", True, update_handler) class UpsertEntity(CreateEntity): @@ -222,7 +254,19 @@ class UpsertEntity(CreateEntity): class DeleteEntity(EntityCommon): @classmethod def delete(cls, id: list): - pass + id = cls._validate_id(id) + db.execute( + f"DELETE FROM {cls._table_name} WHERE {cls._where_clause}", + *[id[key] for key in cls._key_columns], + ) + + @classmethod + def register_route(cls, routes): + async def delete_handler(request): + id = cls._extract_key(request) + cls.delete(id) + + cls._register_route(routes, "delete", True, delete_handler) class BaseEntity(GetEntity, CreateEntity, UpdateEntity, DeleteEntity, GetEntityById): @@ -244,7 +288,7 @@ def column(type_: Any, required: bool = False, key: bool = False, default: Any = def table(table_name: str): def decorator(cls): # Store table name - cls.__table_name__ = table_name + cls._table_name = table_name # Process column definitions columns: dict[str, Column] = {} @@ -253,12 +297,10 @@ def table(table_name: str): columns[attr_name] = attr_value # Store columns metadata - cls.__columns__ = columns - cls.__key_columns__ = [col for col in columns if columns[col].key] - cls.__column_csv__ = ", ".join([col for col in columns]) - cls.__where_clause__ = " AND ".join( - [f"{col} = ?" for col in cls.__key_columns__] - ) + cls._columns = columns + cls._key_columns = [col for col in columns if columns[col].key] + cls._column_csv = ", ".join([col for col in columns]) + cls._where_clause = " AND ".join([f"{col} = ?" for col in cls._key_columns]) # Add initialization original_init = cls.__init__ @@ -266,7 +308,7 @@ def table(table_name: str): @wraps(original_init) def new_init(self, *args, **kwargs): # Initialize columns with default values - for col_name, col_def in cls.__columns__.items(): + for col_name, col_def in cls._columns.items(): setattr(self, col_name, col_def.default) # Call original init original_init(self, *args, **kwargs) diff --git a/tests-unit/app_test/entities_test.py b/tests-unit/app_test/entities_test.py index d0eb59455..f9e791e86 100644 --- a/tests-unit/app_test/entities_test.py +++ b/tests-unit/app_test/entities_test.py @@ -8,6 +8,7 @@ import pytest_asyncio from unittest.mock import patch from aiohttp import web from app.database.entities import ( + DeleteEntity, column, table, Column, @@ -15,6 +16,7 @@ from app.database.entities import ( GetEntityById, CreateEntity, UpsertEntity, + UpdateEntity, ) from app.database.db import db @@ -25,9 +27,9 @@ def create_table(entity): # reset db db.close() - cols: list[Column] = entity.__columns__ + cols: list[Column] = entity._columns # Create tables as temporary so when we close the db, the tables are dropped for next test - sql = f"CREATE TEMPORARY TABLE {entity.__table_name__} ( " + sql = f"CREATE TEMPORARY TABLE {entity._table_name} ( " for col_name, col in cols.items(): type = None if col.type == int: @@ -40,7 +42,7 @@ def create_table(entity): sql += " NOT NULL" sql += ", " - sql += f"PRIMARY KEY ({', '.join(entity.__key_columns__)})" + sql += f"PRIMARY KEY ({', '.join(entity._key_columns)})" sql += ")" db.execute(sql) @@ -48,6 +50,7 @@ def create_table(entity): async def wrap_db(method: Callable, expected_sql: str, expected_args: list): with patch.object(db, "execute", wraps=db.execute) as mock: response = await method() + assert mock.call_count == 1 assert mock.call_args[0][0] == expected_sql assert mock.call_args[0][1:] == expected_args return response @@ -109,6 +112,35 @@ def upsertable_entity(): return UpsertableEntity +@pytest.fixture +def updateable_entity(): + @table("updateable_entity") + class UpdateableEntity(UpdateEntity): + id: int = column(int, required=True, key=True) + reqd: str = column(str, required=True) + + return UpdateableEntity + + +@pytest.fixture +def deletable_entity(): + @table("deletable_entity") + class DeletableEntity(DeleteEntity): + id: int = column(int, required=True, key=True) + + return DeletableEntity + + +@pytest.fixture +def deletable_composite_entity(): + @table("deletable_composite_entity") + class DeletableCompositeEntity(DeleteEntity): + id1: str = column(str, required=True, key=True) + id2: int = column(int, required=True, key=True) + + return DeletableCompositeEntity + + @pytest.fixture() def entity(request): value = request.getfixturevalue(request.param) @@ -399,3 +431,83 @@ async def test_upsert_model(client): "reqd": "reqd1", "nullable": None, } + + +@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True) +async def test_update_model(client): + # seed db + db.execute("INSERT INTO updateable_entity (id, reqd) VALUES (1, 'test1')") + + expected_sql = "UPDATE updateable_entity SET reqd = ? WHERE id = ? RETURNING *" + expected_args = ("updated_test", 1) + response = await wrap_db( + lambda: client.patch("/db/updateable_entity/1", json={"reqd": "updated_test"}), + expected_sql, + expected_args, + ) + + assert response.status == 200 + assert await response.json() == { + "id": 1, + "reqd": "updated_test", + } + + +@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True) +async def test_update_model_reject_null_required_field(client): + response = await client.patch("/db/updateable_entity/1", json={"reqd": None}) + + assert response.status == 400 + assert await response.json() == { + "message": "Required field", + "field": "reqd", + } + + +@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True) +async def test_update_model_reject_invalid_field(client): + response = await client.patch("/db/updateable_entity/1", json={"hello": "world"}) + + assert response.status == 400 + assert await response.json() == { + "message": "Unknown field", + "field": "hello", + } + + +@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True) +async def test_update_model_reject_missing_record(client): + response = await client.patch( + "/db/updateable_entity/1", json={"reqd": "updated_test"} + ) + + assert response.status == 404 + assert await response.json() == { + "message": "Failed to update entity", + } + + +@pytest.mark.parametrize("entity", ["deletable_entity"], indirect=True) +async def test_delete_model(client): + expected_sql = "DELETE FROM deletable_entity WHERE id = ?" + expected_args = (1,) + response = await wrap_db( + lambda: client.delete("/db/deletable_entity/1"), + expected_sql, + expected_args, + ) + + assert response.status == 204 + + +@pytest.mark.parametrize("entity", ["deletable_composite_entity"], indirect=True) +async def test_delete_model_composite_key(client): + expected_sql = "DELETE FROM deletable_composite_entity WHERE id1 = ? AND id2 = ?" + expected_args = ("one", 2) + response = await wrap_db( + lambda: client.delete("/db/deletable_composite_entity/one/2"), + expected_sql, + expected_args, + ) + + assert response.status == 204