# 1
# mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-03 23:49:57 +08:00
# Files
# ComfyUI/tests-unit/app_test/model_manager_test.py
# pythongosssss 7bf381bc9e Add model management and database
# - use sqlalchemy + alembic + sqlite for db
# - extract model data and previews
# - endpoints for db interactions
# - add tests
# 2025-03-28 11:39:56 +08:00
#
# 369 lines
# 12 KiB
# Python

import pytest
import base64
import json
import struct
from io import BytesIO
from PIL import Image
from aiohttp import web
from unittest.mock import patch
from app.model_manager import ModelFileManager
from app.database.models import Base, Model, Tag
from comfy.cli_args import args
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Apply the asyncio mark to every test function in this module.
pytestmark = pytest.mark.asyncio
@pytest.fixture
def session():
    """Yield a SQLAlchemy session bound to a fresh in-memory SQLite database.

    Sets ``args.database_url`` to ``sqlite:///:memory:``, creates every table
    registered on ``Base``, and patches ``app.database.db.Session`` with the
    session factory built here while the test runs.

    On teardown the yielded session is closed, the tables are dropped, and
    the engine is disposed so no connection outlives the test.
    """
    # Configure in-memory database
    args.database_url = "sqlite:///:memory:"

    # Create engine and session factory
    engine = create_engine(args.database_url)
    Session = sessionmaker(bind=engine)

    # Create all tables
    Base.metadata.create_all(engine)

    db_session = Session()
    try:
        # Patch Session factory
        with patch('app.database.db.Session', Session):
            yield db_session
    finally:
        # Release the session's connection before tearing the schema down.
        db_session.close()
        Base.metadata.drop_all(engine)
        engine.dispose()
@pytest.fixture
def model_manager():
    """Provide a newly constructed ``ModelFileManager``."""
    manager = ModelFileManager()
    return manager
@pytest.fixture
def app(model_manager):
    """Build an aiohttp application carrying the model manager's routes."""
    route_table = web.RouteTableDef()
    model_manager.add_routes(route_table)

    application = web.Application()
    application.add_routes(route_table)
    return application
async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
    """The preview endpoint answers with a WebP image for a safetensors file.

    Writes a file made of an 8-byte little-endian header length followed by a
    JSON header whose ``__metadata__`` holds one base64-encoded PNG under
    ``ssmd_cover_images``, then requests that file's preview.
    """
    # Encode a 100x100 white PNG to embed in the header.
    png_buffer = BytesIO()
    with Image.new('RGB', (100, 100), 'white') as cover:
        cover.save(png_buffer, format='PNG')
    img_b64 = base64.b64encode(png_buffer.getvalue()).decode('utf-8')

    header_bytes = json.dumps({
        "__metadata__": {
            "ssmd_cover_images": json.dumps([img_b64])
        }
    }).encode('utf-8')
    length_bytes = struct.pack('<Q', len(header_bytes))

    safetensors_file = tmp_path / "test_model.safetensors"
    with open(safetensors_file, 'wb') as f:
        f.write(length_bytes)
        f.write(header_bytes)

    with patch('folder_paths.folder_names_and_paths', {
        'test_folder': ([str(tmp_path)], None)
    }):
        client = await aiohttp_client(app)
        response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors')

        # Verify response
        assert response.status == 200
        assert response.content_type == 'image/webp'

        # Verify the response contains valid image data; the context manager
        # closes the image even when an assertion below fails.
        with Image.open(BytesIO(await response.read())) as preview:
            assert preview.format
            assert preview.format.lower() == 'webp'
async def test_get_models(aiohttp_client, app, session):
    """GET /v2/models returns the stored model together with its tag."""
    stored_tag = Tag(name="test_tag")
    stored_model = Model(
        type="checkpoints",
        path="model1.safetensors",
        title="Test Model",
    )
    stored_model.tags.append(stored_tag)
    session.add(stored_tag)
    session.add(stored_model)
    session.commit()

    http = await aiohttp_client(app)
    response = await http.get("/v2/models")
    assert response.status == 200

    models = await response.json()
    assert len(models) == 1
    assert models[0]["path"] == "model1.safetensors"
    assert len(models[0]["tags"]) == 1
    assert models[0]["tags"][0]["name"] == "test_tag"
async def test_add_model(aiohttp_client, app, session):
    """POST /v2/models returns the new model with its tag and runs processing.

    ``model_processor`` is replaced by a mock and ``get_full_path`` is patched
    to return a path for the duration of the request.
    """
    existing_tag = Tag(name="test_tag")
    session.add(existing_tag)
    session.commit()
    tag_id = existing_tag.id

    payload = {
        "type": "checkpoints",
        "path": "model1.safetensors",
        "title": "Test Model",
        "tags": [tag_id],
    }
    processor_patch = patch("app.model_manager.model_processor")
    path_patch = patch(
        "app.model_manager.get_full_path",
        return_value="/checkpoints/model1.safetensors",
    )
    with processor_patch as mock_processor, path_patch:
        http = await aiohttp_client(app)
        response = await http.post("/v2/models", json=payload)
        assert response.status == 200

        body = await response.json()
        assert body["path"] == "model1.safetensors"
        assert len(body["tags"]) == 1
        assert body["tags"][0]["name"] == "test_tag"

        # Ensure that models are re-processed after adding
        mock_processor.run.assert_called_once()
async def test_delete_model(aiohttp_client, app, session):
    """DELETE /v2/models answers 204 and removes the stored row.

    ``get_full_path`` is patched to return ``None`` for the request.
    """
    stored_model = Model(
        type="checkpoints",
        path="model1.safetensors",
        title="Test Model",
    )
    session.add(stored_model)
    session.commit()

    with patch("app.model_manager.get_full_path", return_value=None):
        http = await aiohttp_client(app)
        response = await http.delete("/v2/models?type=checkpoints&path=model1.safetensors")
        assert response.status == 204

        # Verify model was deleted
        remaining = session.query(Model).first()
        assert remaining is None
async def test_delete_model_file_exists(aiohttp_client, app, session):
    """DELETE /v2/models answers 400 and keeps the row when a path is found.

    ``get_full_path`` is patched to return a path for the request.
    """
    stored_model = Model(
        type="checkpoints",
        path="model1.safetensors",
        title="Test Model",
    )
    session.add(stored_model)
    session.commit()

    path_patch = patch(
        "app.model_manager.get_full_path",
        return_value="/checkpoints/model1.safetensors",
    )
    with path_patch:
        http = await aiohttp_client(app)
        response = await http.delete("/v2/models?type=checkpoints&path=model1.safetensors")
        assert response.status == 400

        body = await response.json()
        assert "file exists" in body["error"].lower()

        # Verify model was not deleted
        remaining = session.query(Model).first()
        assert remaining is not None
        assert remaining.path == "model1.safetensors"
async def test_get_tags(aiohttp_client, app, session):
    """GET /v2/tags returns every stored tag."""
    session.add_all([Tag(name="tag1"), Tag(name="tag2")])
    session.commit()

    http = await aiohttp_client(app)
    response = await http.get("/v2/tags")
    assert response.status == 200

    listed = await response.json()
    assert len(listed) == 2
    names = {entry["name"] for entry in listed}
    assert names == {"tag1", "tag2"}
async def test_create_tag(aiohttp_client, app, session):
    """POST /v2/tags returns the new tag and stores it in the database."""
    client = await aiohttp_client(app)
    resp = await client.post('/v2/tags', json={'name': 'new_tag'})
    assert resp.status == 200
    data = await resp.json()
    assert data['name'] == 'new_tag'

    # Verify tag was created; check for None first so a missing row fails as
    # an assertion rather than an AttributeError on ``tag.name``.
    tag = session.query(Tag).first()
    assert tag is not None
    assert tag.name == 'new_tag'
async def test_delete_tag(aiohttp_client, app, session):
    """DELETE /v2/tags answers 204 and removes the stored tag."""
    doomed_tag = Tag(name="test_tag")
    session.add(doomed_tag)
    session.commit()
    tag_id = doomed_tag.id

    http = await aiohttp_client(app)
    response = await http.delete(f"/v2/tags?id={tag_id}")
    assert response.status == 204

    # Verify tag was deleted
    remaining = session.query(Tag).first()
    assert remaining is None
async def test_add_model_tag(aiohttp_client, app, session):
    """POST /v2/models/tags returns the model with the tag attached."""
    new_tag = Tag(name="test_tag")
    stored_model = Model(
        type="checkpoints",
        path="model1.safetensors",
        title="Test Model",
    )
    session.add(new_tag)
    session.add(stored_model)
    session.commit()
    tag_id = new_tag.id

    payload = {
        "tag": tag_id,
        "type": "checkpoints",
        "path": "model1.safetensors",
    }
    http = await aiohttp_client(app)
    response = await http.post("/v2/models/tags", json=payload)
    assert response.status == 200

    body = await response.json()
    assert len(body["tags"]) == 1
    assert body["tags"][0]["name"] == "test_tag"
async def test_delete_model_tag(aiohttp_client, app, session):
    """DELETE /v2/models/tags answers 204 and detaches the tag from the model."""
    tag = Tag(name='test_tag')
    model = Model(
        type='checkpoints',
        path='model1.safetensors',
        title='Test Model'
    )
    model.tags.append(tag)
    session.add(tag)
    session.add(model)
    session.commit()
    tag_id = tag.id

    client = await aiohttp_client(app)
    resp = await client.delete(f'/v2/models/tags?tag={tag_id}&type=checkpoints&path=model1.safetensors')
    assert resp.status == 204

    # Verify tag was removed; check for None first so a missing model fails
    # as an assertion rather than an AttributeError on ``model.tags``.
    model = session.query(Model).first()
    assert model is not None
    assert len(model.tags) == 0
async def test_add_model_duplicate(aiohttp_client, app, session):
    """POST /v2/models answers 400 when the same type and path are already stored."""
    stored_model = Model(
        type="checkpoints",
        path="model1.safetensors",
        title="Test Model",
    )
    session.add(stored_model)
    session.commit()

    payload = {
        "type": "checkpoints",
        "path": "model1.safetensors",
        "title": "Duplicate Model",
    }
    path_patch = patch(
        "app.model_manager.get_full_path",
        return_value="/checkpoints/model1.safetensors",
    )
    with path_patch:
        http = await aiohttp_client(app)
        response = await http.post("/v2/models", json=payload)
        assert response.status == 400
async def test_add_model_missing_fields(aiohttp_client, app, session):
    """POST /v2/models with an empty JSON object answers 400."""
    empty_payload = {}
    http = await aiohttp_client(app)
    response = await http.post("/v2/models", json=empty_payload)
    assert response.status == 400
async def test_add_tag_missing_name(aiohttp_client, app, session):
    """POST /v2/tags with an empty JSON object answers 400."""
    empty_payload = {}
    http = await aiohttp_client(app)
    response = await http.post("/v2/tags", json=empty_payload)
    assert response.status == 400
async def test_delete_model_not_found(aiohttp_client, app, session):
    """DELETE /v2/models answers 404 when no model is stored for the path."""
    target = "/v2/models?type=checkpoints&path=nonexistent.safetensors"
    http = await aiohttp_client(app)
    response = await http.delete(target)
    assert response.status == 404
async def test_delete_tag_not_found(aiohttp_client, app, session):
    """DELETE /v2/tags answers 404 when no tag is stored for the id."""
    target = "/v2/tags?id=999"
    http = await aiohttp_client(app)
    response = await http.delete(target)
    assert response.status == 404
async def test_add_model_missing_path(aiohttp_client, app, session):
    """POST /v2/models without ``path`` answers 400 with an error naming it."""
    payload = {
        "type": "checkpoints",
        "title": "Test Model",
    }
    http = await aiohttp_client(app)
    response = await http.post("/v2/models", json=payload)
    assert response.status == 400

    body = await response.json()
    error_text = body["error"].lower()
    assert "path" in error_text
async def test_add_model_invalid_field(aiohttp_client, app, session):
    """POST /v2/models with an unknown key answers 400 with "invalid field"."""
    payload = {
        "type": "checkpoints",
        "path": "model1.safetensors",
        "invalid_field": "some value",
    }
    http = await aiohttp_client(app)
    response = await http.post("/v2/models", json=payload)
    assert response.status == 400

    body = await response.json()
    error_text = body["error"].lower()
    assert "invalid field" in error_text
async def test_add_model_nonexistent_file(aiohttp_client, app, session):
    """POST /v2/models answers 404 when ``get_full_path`` returns ``None``."""
    payload = {
        "type": "checkpoints",
        "path": "nonexistent.safetensors",
    }
    with patch("app.model_manager.get_full_path", return_value=None):
        http = await aiohttp_client(app)
        response = await http.post("/v2/models", json=payload)
        assert response.status == 404

        body = await response.json()
        error_text = body["error"].lower()
        assert "file" in error_text
async def test_add_model_invalid_tag(aiohttp_client, app, session):
    """POST /v2/models answers 404 when a listed tag id is not stored."""
    payload = {
        "type": "checkpoints",
        "path": "model1.safetensors",
        "tags": [999],  # Non-existent tag ID
    }
    path_patch = patch(
        "app.model_manager.get_full_path",
        return_value="/checkpoints/model1.safetensors",
    )
    with path_patch:
        http = await aiohttp_client(app)
        response = await http.post("/v2/models", json=payload)
        assert response.status == 404

        body = await response.json()
        error_text = body["error"].lower()
        assert "tag" in error_text
async def test_add_tag_to_nonexistent_model(aiohttp_client, app, session):
    """POST /v2/models/tags answers 404 when the target model is not stored."""
    # Only the tag is stored; no model row is created.
    lone_tag = Tag(name="test_tag")
    session.add(lone_tag)
    session.commit()
    tag_id = lone_tag.id

    payload = {
        "tag": tag_id,
        "type": "checkpoints",
        "path": "nonexistent.safetensors",
    }
    http = await aiohttp_client(app)
    response = await http.post("/v2/models/tags", json=payload)
    assert response.status == 404

    body = await response.json()
    error_text = body["error"].lower()
    assert "model" in error_text
async def test_delete_model_tag_invalid_tag_id(aiohttp_client, app, session):
    """DELETE /v2/models/tags answers 400 when ``tag`` is not a number.

    The query string names the stored model's own type and path, so the
    non-numeric tag id is the only invalid part of the request.
    """
    # Create a model first
    model = Model(
        type='checkpoints',
        path='model1.safetensors',
        title='Test Model'
    )
    session.add(model)
    session.commit()

    client = await aiohttp_client(app)
    # ``type`` must match the stored model ('checkpoints', not 'checkpoint').
    resp = await client.delete('/v2/models/tags?tag=not_a_number&type=checkpoints&path=model1.safetensors')
    assert resp.status == 400
    data = await resp.json()
    assert "invalid tag id" in data["error"].lower()