1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-04 07:52:46 +08:00

Add model management and database

- use sqlalchemy + alembic + sqlite for db
- extract model data and previews
- endpoints for db interactions
- add tests
This commit is contained in:
pythongosssss
2025-03-28 11:39:56 +08:00
parent 1709a8441e
commit 7bf381bc9e
14 changed files with 1264 additions and 40 deletions

View File

@@ -7,11 +7,33 @@ from PIL import Image
from aiohttp import web
from unittest.mock import patch
from app.model_manager import ModelFileManager
from app.database.models import Base, Model, Tag
from comfy.cli_args import args
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def session():
# Configure in-memory database
args.database_url = "sqlite:///:memory:"
# Create engine and session factory
engine = create_engine(args.database_url)
Session = sessionmaker(bind=engine)
# Create all tables
Base.metadata.create_all(engine)
# Patch Session factory
with patch('app.database.db.Session', Session):
yield Session()
Base.metadata.drop_all(engine)
@pytest.fixture
def model_manager():
return ModelFileManager()
@@ -60,3 +82,287 @@ async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
# Clean up
img.close()
async def test_get_models(aiohttp_client, app, session):
tag = Tag(name='test_tag')
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
model.tags.append(tag)
session.add(tag)
session.add(model)
session.commit()
client = await aiohttp_client(app)
resp = await client.get('/v2/models')
assert resp.status == 200
data = await resp.json()
assert len(data) == 1
assert data[0]['path'] == 'model1.safetensors'
assert len(data[0]['tags']) == 1
assert data[0]['tags'][0]['name'] == 'test_tag'
async def test_add_model(aiohttp_client, app, session):
tag = Tag(name='test_tag')
session.add(tag)
session.commit()
tag_id = tag.id
with patch('app.model_manager.model_processor') as mock_processor:
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'title': 'Test Model',
'tags': [tag_id]
})
assert resp.status == 200
data = await resp.json()
assert data['path'] == 'model1.safetensors'
assert len(data['tags']) == 1
assert data['tags'][0]['name'] == 'test_tag'
# Ensure that models are re-processed after adding
mock_processor.run.assert_called_once()
async def test_delete_model(aiohttp_client, app, session):
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
with patch('app.model_manager.get_full_path', return_value=None):
client = await aiohttp_client(app)
resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
assert resp.status == 204
# Verify model was deleted
model = session.query(Model).first()
assert model is None
async def test_delete_model_file_exists(aiohttp_client, app, session):
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
assert resp.status == 400
data = await resp.json()
assert "file exists" in data["error"].lower()
# Verify model was not deleted
model = session.query(Model).first()
assert model is not None
assert model.path == 'model1.safetensors'
async def test_get_tags(aiohttp_client, app, session):
tags = [Tag(name='tag1'), Tag(name='tag2')]
for tag in tags:
session.add(tag)
session.commit()
client = await aiohttp_client(app)
resp = await client.get('/v2/tags')
assert resp.status == 200
data = await resp.json()
assert len(data) == 2
assert {t['name'] for t in data} == {'tag1', 'tag2'}
async def test_create_tag(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/tags', json={'name': 'new_tag'})
assert resp.status == 200
data = await resp.json()
assert data['name'] == 'new_tag'
# Verify tag was created
tag = session.query(Tag).first()
assert tag.name == 'new_tag'
async def test_delete_tag(aiohttp_client, app, session):
tag = Tag(name='test_tag')
session.add(tag)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.delete(f'/v2/tags?id={tag_id}')
assert resp.status == 204
# Verify tag was deleted
tag = session.query(Tag).first()
assert tag is None
async def test_add_model_tag(aiohttp_client, app, session):
tag = Tag(name='test_tag')
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(tag)
session.add(model)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.post('/v2/models/tags', json={
'tag': tag_id,
'type': 'checkpoints',
'path': 'model1.safetensors'
})
assert resp.status == 200
data = await resp.json()
assert len(data['tags']) == 1
assert data['tags'][0]['name'] == 'test_tag'
async def test_delete_model_tag(aiohttp_client, app, session):
tag = Tag(name='test_tag')
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
model.tags.append(tag)
session.add(tag)
session.add(model)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.delete(f'/v2/models/tags?tag={tag_id}&type=checkpoints&path=model1.safetensors')
assert resp.status == 204
# Verify tag was removed
model = session.query(Model).first()
assert len(model.tags) == 0
async def test_add_model_duplicate(aiohttp_client, app, session):
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'title': 'Duplicate Model'
})
assert resp.status == 400
async def test_add_model_missing_fields(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={})
assert resp.status == 400
async def test_add_tag_missing_name(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/tags', json={})
assert resp.status == 400
async def test_delete_model_not_found(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.delete('/v2/models?type=checkpoints&path=nonexistent.safetensors')
assert resp.status == 404
async def test_delete_tag_not_found(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.delete('/v2/tags?id=999')
assert resp.status == 404
async def test_add_model_missing_path(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'title': 'Test Model'
})
assert resp.status == 400
data = await resp.json()
assert "path" in data["error"].lower()
async def test_add_model_invalid_field(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'invalid_field': 'some value'
})
assert resp.status == 400
data = await resp.json()
assert "invalid field" in data["error"].lower()
async def test_add_model_nonexistent_file(aiohttp_client, app, session):
with patch('app.model_manager.get_full_path', return_value=None):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'nonexistent.safetensors'
})
assert resp.status == 404
data = await resp.json()
assert "file" in data["error"].lower()
async def test_add_model_invalid_tag(aiohttp_client, app, session):
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'tags': [999] # Non-existent tag ID
})
assert resp.status == 404
data = await resp.json()
assert "tag" in data["error"].lower()
async def test_add_tag_to_nonexistent_model(aiohttp_client, app, session):
# Create a tag but no model
tag = Tag(name='test_tag')
session.add(tag)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.post('/v2/models/tags', json={
'tag': tag_id,
'type': 'checkpoints',
'path': 'nonexistent.safetensors'
})
assert resp.status == 404
data = await resp.json()
assert "model" in data["error"].lower()
async def test_delete_model_tag_invalid_tag_id(aiohttp_client, app, session):
# Create a model first
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
client = await aiohttp_client(app)
resp = await client.delete('/v2/models/tags?tag=not_a_number&type=checkpoint&path=model1.safetensors')
assert resp.status == 400
data = await resp.json()
assert "invalid tag id" in data["error"].lower()