mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-04 07:52:46 +08:00
Add model management and database
- use sqlalchemy + alembic + sqlite for db - extract model data and previews - endpoints for db interactions - add tests
This commit is contained in:
@@ -7,11 +7,33 @@ from PIL import Image
|
||||
from aiohttp import web
|
||||
from unittest.mock import patch
|
||||
from app.model_manager import ModelFileManager
|
||||
from app.database.models import Base, Model, Tag
|
||||
from comfy.cli_args import args
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
# Configure in-memory database
|
||||
args.database_url = "sqlite:///:memory:"
|
||||
|
||||
# Create engine and session factory
|
||||
engine = create_engine(args.database_url)
|
||||
Session = sessionmaker(bind=engine)
|
||||
|
||||
# Create all tables
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
# Patch Session factory
|
||||
with patch('app.database.db.Session', Session):
|
||||
yield Session()
|
||||
|
||||
Base.metadata.drop_all(engine)
|
||||
|
||||
@pytest.fixture
|
||||
def model_manager():
|
||||
return ModelFileManager()
|
||||
@@ -60,3 +82,287 @@ async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
|
||||
|
||||
# Clean up
|
||||
img.close()
|
||||
|
||||
async def test_get_models(aiohttp_client, app, session):
|
||||
tag = Tag(name='test_tag')
|
||||
model = Model(
|
||||
type='checkpoints',
|
||||
path='model1.safetensors',
|
||||
title='Test Model'
|
||||
)
|
||||
model.tags.append(tag)
|
||||
session.add(tag)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get('/v2/models')
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]['path'] == 'model1.safetensors'
|
||||
assert len(data[0]['tags']) == 1
|
||||
assert data[0]['tags'][0]['name'] == 'test_tag'
|
||||
|
||||
async def test_add_model(aiohttp_client, app, session):
|
||||
tag = Tag(name='test_tag')
|
||||
session.add(tag)
|
||||
session.commit()
|
||||
tag_id = tag.id
|
||||
|
||||
with patch('app.model_manager.model_processor') as mock_processor:
|
||||
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post('/v2/models', json={
|
||||
'type': 'checkpoints',
|
||||
'path': 'model1.safetensors',
|
||||
'title': 'Test Model',
|
||||
'tags': [tag_id]
|
||||
})
|
||||
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data['path'] == 'model1.safetensors'
|
||||
assert len(data['tags']) == 1
|
||||
assert data['tags'][0]['name'] == 'test_tag'
|
||||
|
||||
# Ensure that models are re-processed after adding
|
||||
mock_processor.run.assert_called_once()
|
||||
|
||||
async def test_delete_model(aiohttp_client, app, session):
|
||||
model = Model(
|
||||
type='checkpoints',
|
||||
path='model1.safetensors',
|
||||
title='Test Model'
|
||||
)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
|
||||
with patch('app.model_manager.get_full_path', return_value=None):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
|
||||
assert resp.status == 204
|
||||
|
||||
# Verify model was deleted
|
||||
model = session.query(Model).first()
|
||||
assert model is None
|
||||
|
||||
async def test_delete_model_file_exists(aiohttp_client, app, session):
|
||||
model = Model(
|
||||
type='checkpoints',
|
||||
path='model1.safetensors',
|
||||
title='Test Model'
|
||||
)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
|
||||
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
|
||||
assert resp.status == 400
|
||||
|
||||
data = await resp.json()
|
||||
assert "file exists" in data["error"].lower()
|
||||
|
||||
# Verify model was not deleted
|
||||
model = session.query(Model).first()
|
||||
assert model is not None
|
||||
assert model.path == 'model1.safetensors'
|
||||
|
||||
async def test_get_tags(aiohttp_client, app, session):
|
||||
tags = [Tag(name='tag1'), Tag(name='tag2')]
|
||||
for tag in tags:
|
||||
session.add(tag)
|
||||
session.commit()
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get('/v2/tags')
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert len(data) == 2
|
||||
assert {t['name'] for t in data} == {'tag1', 'tag2'}
|
||||
|
||||
async def test_create_tag(aiohttp_client, app, session):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post('/v2/tags', json={'name': 'new_tag'})
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data['name'] == 'new_tag'
|
||||
|
||||
# Verify tag was created
|
||||
tag = session.query(Tag).first()
|
||||
assert tag.name == 'new_tag'
|
||||
|
||||
async def test_delete_tag(aiohttp_client, app, session):
|
||||
tag = Tag(name='test_tag')
|
||||
session.add(tag)
|
||||
session.commit()
|
||||
tag_id = tag.id
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.delete(f'/v2/tags?id={tag_id}')
|
||||
assert resp.status == 204
|
||||
|
||||
# Verify tag was deleted
|
||||
tag = session.query(Tag).first()
|
||||
assert tag is None
|
||||
|
||||
async def test_add_model_tag(aiohttp_client, app, session):
|
||||
tag = Tag(name='test_tag')
|
||||
model = Model(
|
||||
type='checkpoints',
|
||||
path='model1.safetensors',
|
||||
title='Test Model'
|
||||
)
|
||||
session.add(tag)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
tag_id = tag.id
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post('/v2/models/tags', json={
|
||||
'tag': tag_id,
|
||||
'type': 'checkpoints',
|
||||
'path': 'model1.safetensors'
|
||||
})
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert len(data['tags']) == 1
|
||||
assert data['tags'][0]['name'] == 'test_tag'
|
||||
|
||||
async def test_delete_model_tag(aiohttp_client, app, session):
|
||||
tag = Tag(name='test_tag')
|
||||
model = Model(
|
||||
type='checkpoints',
|
||||
path='model1.safetensors',
|
||||
title='Test Model'
|
||||
)
|
||||
model.tags.append(tag)
|
||||
session.add(tag)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
tag_id = tag.id
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.delete(f'/v2/models/tags?tag={tag_id}&type=checkpoints&path=model1.safetensors')
|
||||
assert resp.status == 204
|
||||
|
||||
# Verify tag was removed
|
||||
model = session.query(Model).first()
|
||||
assert len(model.tags) == 0
|
||||
|
||||
async def test_add_model_duplicate(aiohttp_client, app, session):
|
||||
model = Model(
|
||||
type='checkpoints',
|
||||
path='model1.safetensors',
|
||||
title='Test Model'
|
||||
)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
|
||||
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post('/v2/models', json={
|
||||
'type': 'checkpoints',
|
||||
'path': 'model1.safetensors',
|
||||
'title': 'Duplicate Model'
|
||||
})
|
||||
assert resp.status == 400
|
||||
|
||||
async def test_add_model_missing_fields(aiohttp_client, app, session):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post('/v2/models', json={})
|
||||
assert resp.status == 400
|
||||
|
||||
async def test_add_tag_missing_name(aiohttp_client, app, session):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post('/v2/tags', json={})
|
||||
assert resp.status == 400
|
||||
|
||||
async def test_delete_model_not_found(aiohttp_client, app, session):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.delete('/v2/models?type=checkpoints&path=nonexistent.safetensors')
|
||||
assert resp.status == 404
|
||||
|
||||
async def test_delete_tag_not_found(aiohttp_client, app, session):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.delete('/v2/tags?id=999')
|
||||
assert resp.status == 404
|
||||
|
||||
async def test_add_model_missing_path(aiohttp_client, app, session):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post('/v2/models', json={
|
||||
'type': 'checkpoints',
|
||||
'title': 'Test Model'
|
||||
})
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "path" in data["error"].lower()
|
||||
|
||||
async def test_add_model_invalid_field(aiohttp_client, app, session):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post('/v2/models', json={
|
||||
'type': 'checkpoints',
|
||||
'path': 'model1.safetensors',
|
||||
'invalid_field': 'some value'
|
||||
})
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "invalid field" in data["error"].lower()
|
||||
|
||||
async def test_add_model_nonexistent_file(aiohttp_client, app, session):
|
||||
with patch('app.model_manager.get_full_path', return_value=None):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post('/v2/models', json={
|
||||
'type': 'checkpoints',
|
||||
'path': 'nonexistent.safetensors'
|
||||
})
|
||||
assert resp.status == 404
|
||||
data = await resp.json()
|
||||
assert "file" in data["error"].lower()
|
||||
|
||||
async def test_add_model_invalid_tag(aiohttp_client, app, session):
|
||||
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post('/v2/models', json={
|
||||
'type': 'checkpoints',
|
||||
'path': 'model1.safetensors',
|
||||
'tags': [999] # Non-existent tag ID
|
||||
})
|
||||
assert resp.status == 404
|
||||
data = await resp.json()
|
||||
assert "tag" in data["error"].lower()
|
||||
|
||||
async def test_add_tag_to_nonexistent_model(aiohttp_client, app, session):
|
||||
# Create a tag but no model
|
||||
tag = Tag(name='test_tag')
|
||||
session.add(tag)
|
||||
session.commit()
|
||||
tag_id = tag.id
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post('/v2/models/tags', json={
|
||||
'tag': tag_id,
|
||||
'type': 'checkpoints',
|
||||
'path': 'nonexistent.safetensors'
|
||||
})
|
||||
assert resp.status == 404
|
||||
data = await resp.json()
|
||||
assert "model" in data["error"].lower()
|
||||
|
||||
async def test_delete_model_tag_invalid_tag_id(aiohttp_client, app, session):
|
||||
# Create a model first
|
||||
model = Model(
|
||||
type='checkpoints',
|
||||
path='model1.safetensors',
|
||||
title='Test Model'
|
||||
)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.delete('/v2/models/tags?tag=not_a_number&type=checkpoint&path=model1.safetensors')
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "invalid tag id" in data["error"].lower()
|
||||
|
||||
|
Reference in New Issue
Block a user