1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-03 23:49:57 +08:00

Compare commits

...

11 Commits

Author SHA1 Message Date
pythongosssss
d7062277a7 fix bad merge 2025-08-03 16:40:27 +01:00
pythongosssss
54cf14cbbb Merge remote-tracking branch 'origin/master' into pysssss-model-db 2025-08-03 16:36:49 +01:00
comfyanonymous
aebac22193 Cleanup. (#9160) 2025-08-03 07:08:11 -04:00
comfyanonymous
13aaa66ec2 Make sure context is on the right device. (#9154) 2025-08-02 15:09:23 -04:00
comfyanonymous
5f582a9757 Make sure all the conds are on the right device. (#9151) 2025-08-02 15:00:13 -04:00
ComfyUI Wiki
fbcc23945d Update template to 0.1.47 (#9153) 2025-08-02 14:15:29 -04:00
Johnpaul Chiwetelu
3dfefc88d0 API for Recently Used Items (#8792)
* feat: add file creation time to model file metadata and user file info

* fix linting
2025-08-01 22:02:06 -04:00
pythongosssss
7d5160f92c Tidy 2025-06-01 15:45:15 +01:00
pythongosssss
7f7b3f1695 tidy 2025-06-01 15:41:00 +01:00
pythongosssss
9da6aca0d0 Add additional db model metadata fields and model downloading function 2025-06-01 15:32:13 +01:00
pythongosssss
1cb3c98947 Implement database & model hashing 2025-06-01 15:32:02 +01:00
14 changed files with 790 additions and 30 deletions

View File

@@ -0,0 +1,40 @@
"""init
Revision ID: e9c714da8d57
Revises:
Create Date: 2025-05-30 20:14:33.772039
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'e9c714da8d57'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.create_table('model',
sa.Column('type', sa.Text(), nullable=False),
sa.Column('path', sa.Text(), nullable=False),
sa.Column('file_name', sa.Text(), nullable=True),
sa.Column('file_size', sa.Integer(), nullable=True),
sa.Column('hash', sa.Text(), nullable=True),
sa.Column('hash_algorithm', sa.Text(), nullable=True),
sa.Column('source_url', sa.Text(), nullable=True),
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
sa.PrimaryKeyConstraint('type', 'path')
)
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('model')
# ### end Alembic commands ###

View File

@@ -1,4 +1,11 @@
from sqlalchemy import (
Column,
Integer,
Text,
DateTime,
)
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql import func
Base = declarative_base()
@@ -11,4 +18,42 @@ def to_dict(obj):
if (val := getattr(obj, field))
}
# TODO: Define models here
class Model(Base):
"""
sqlalchemy model representing a model file in the system.
This class defines the database schema for storing information about model files,
including their type, path, hash, and when they were added to the system.
Attributes:
type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
path (Text): The file path of the model relative to the type folder (primary key)
file_name (Text): The name of the model file
file_size (Integer): The size of the model file in bytes
hash (Text): A hash of the model file
hash_algorithm (Text): The algorithm used to generate the hash
source_url (Text): The URL of the model file
date_added (DateTime): Timestamp of when the model was added to the system
"""
__tablename__ = "model"
type = Column(Text, primary_key=True)
path = Column(Text, primary_key=True)
file_name = Column(Text)
file_size = Column(Integer)
hash = Column(Text)
hash_algorithm = Column(Text)
source_url = Column(Text)
date_added = Column(DateTime, server_default=func.now())
def to_dict(self):
"""
Convert the model instance to a dictionary representation.
Returns:
dict: A dictionary containing the attributes of the model
"""
dict = to_dict(self)
return dict

View File

@@ -196,6 +196,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
"""
A class to manage ComfyUI frontend versions and installations.
This class handles the initialization and management of different frontend versions,
including the default frontend from the pip package and custom frontend versions
from GitHub repositories.
Attributes:
CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored.
"""
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
@@ -205,6 +216,15 @@ class FrontendManager:
@classmethod
def default_frontend_path(cls) -> str:
"""
Get the path to the default frontend installation from the pip package.
Returns:
str: The path to the default frontend static files.
Raises:
SystemExit: If the comfyui-frontend-package is not installed.
"""
try:
import comfyui_frontend_package
@@ -225,6 +245,15 @@ comfyui-frontend-package is not installed.
@classmethod
def templates_path(cls) -> str:
"""
Get the path to the workflow templates.
Returns:
str: The path to the workflow templates directory.
Raises:
SystemExit: If the comfyui-workflow-templates package is not installed.
"""
try:
import comfyui_workflow_templates
@@ -260,11 +289,16 @@ comfyui-workflow-templates is not installed.
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
Parse a version string into its components.
The version string should be in the format: 'owner/repo@version'
where version can be either a semantic version (v1.2.3) or 'latest'.
Args:
value (str): The version string to parse.
Returns:
tuple[str, str]: A tuple containing provider name and version.
tuple[str, str, str]: A tuple containing (owner, repo, version).
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
@@ -281,18 +315,22 @@ comfyui-workflow-templates is not installed.
cls, version_string: str, provider: Optional[FrontEndProvider] = None
) -> str:
"""
Initializes the frontend for the specified version.
Initialize a frontend version without error handling.
This method attempts to initialize a specific frontend version, either from
the default pip package or from a custom GitHub repository. It will download
and extract the frontend files if necessary.
Args:
version_string (str): The version string.
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
version_string (str): The version string specifying which frontend to use.
provider (FrontEndProvider, optional): The provider to use for custom frontends.
Returns:
str: The path to the initialized frontend.
Raises:
Exception: If there is an error during the initialization process.
main error source might be request timeout or invalid URL.
Exception: If there is an error during initialization (e.g., network timeout,
invalid URL, or missing assets).
"""
if version_string == DEFAULT_VERSION_STRING:
check_frontend_version()
@@ -344,13 +382,17 @@ comfyui-workflow-templates is not installed.
@classmethod
def init_frontend(cls, version_string: str) -> str:
"""
Initializes the frontend with the specified version string.
Initialize a frontend version with error handling.
This is the main method to initialize a frontend version. It wraps init_frontend_unsafe
with error handling, falling back to the default frontend if initialization fails.
Args:
version_string (str): The version string to initialize the frontend with.
version_string (str): The version string specifying which frontend to use.
Returns:
str: The path of the initialized frontend.
str: The path to the initialized frontend. If initialization fails,
returns the path to the default frontend.
"""
try:
return cls.init_frontend_unsafe(version_string)

View File

@@ -130,10 +130,21 @@ class ModelFileManager:
for file_name in filenames:
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
full_path = os.path.join(dirpath, file_name)
relative_path = os.path.relpath(full_path, directory)
# Get file metadata
file_info = {
"name": relative_path,
"pathIndex": pathIndex,
"modified": os.path.getmtime(full_path), # Add modification time
"created": os.path.getctime(full_path), # Add creation time
"size": os.path.getsize(full_path) # Add file size
}
result.append(file_info)
except Exception as e:
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
continue
for d in subdirs:
@@ -144,7 +155,7 @@ class ModelFileManager:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
return result, dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)

331
app/model_processor.py Normal file
View File

@@ -0,0 +1,331 @@
import os
import logging
import time
import requests
from tqdm import tqdm
from folder_paths import get_relative_path, get_full_path
from app.database.db import create_session, dependencies_available, can_create_session
import blake3
import comfy.utils
if dependencies_available():
from app.database.models import Model
class ModelProcessor:
def _validate_path(self, model_path):
try:
if not self._file_exists(model_path):
logging.error(f"Model file not found: {model_path}")
return None
result = get_relative_path(model_path)
if not result:
logging.error(
f"Model file not in a recognized model directory: {model_path}"
)
return None
return result
except Exception as e:
logging.error(f"Error validating model path {model_path}: {str(e)}")
return None
def _file_exists(self, path):
"""Check if a file exists."""
return os.path.exists(path)
def _get_file_size(self, path):
"""Get file size."""
return os.path.getsize(path)
def _get_hasher(self):
return blake3.blake3()
def _hash_file(self, model_path):
try:
hasher = self._get_hasher()
with open(model_path, "rb", buffering=0) as f:
b = bytearray(128 * 1024)
mv = memoryview(b)
while n := f.readinto(mv):
hasher.update(mv[:n])
return hasher.hexdigest()
except Exception as e:
logging.error(f"Error hashing file {model_path}: {str(e)}")
return None
def _get_existing_model(self, session, model_type, model_relative_path):
return (
session.query(Model)
.filter(Model.type == model_type)
.filter(Model.path == model_relative_path)
.first()
)
def _ensure_source_url(self, session, model, source_url):
if model.source_url is None:
model.source_url = source_url
session.commit()
def _update_database(
self,
session,
model_type,
model_path,
model_relative_path,
model_hash,
model,
source_url,
):
try:
if not model:
model = self._get_existing_model(
session, model_type, model_relative_path
)
if not model:
model = Model(
path=model_relative_path,
type=model_type,
file_name=os.path.basename(model_path),
)
session.add(model)
model.file_size = self._get_file_size(model_path)
model.hash = model_hash
if model_hash:
model.hash_algorithm = "blake3"
model.source_url = source_url
session.commit()
return model
except Exception as e:
logging.error(
f"Error updating database for {model_relative_path}: {str(e)}"
)
def process_file(self, model_path, source_url=None, model_hash=None):
"""
Process a model file and update the database with metadata.
If the file already exists and matches the database, it will not be processed again.
Returns the model object or if an error occurs, returns None.
"""
try:
if not can_create_session():
return
result = self._validate_path(model_path)
if not result:
return
model_type, model_relative_path = result
with create_session() as session:
session.expire_on_commit = False
existing_model = self._get_existing_model(
session, model_type, model_relative_path
)
if (
existing_model
and existing_model.hash
and existing_model.file_size == self._get_file_size(model_path)
):
# File exists with hash and same size, no need to process
self._ensure_source_url(session, existing_model, source_url)
return existing_model
if model_hash:
model_hash = model_hash.lower()
logging.info(f"Using provided hash: {model_hash}")
else:
start_time = time.time()
logging.info(f"Hashing model {model_relative_path}")
model_hash = self._hash_file(model_path)
if not model_hash:
return
logging.info(
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
)
return self._update_database(
session,
model_type,
model_path,
model_relative_path,
model_hash,
existing_model,
source_url,
)
except Exception as e:
logging.error(f"Error processing model file {model_path}: {str(e)}")
return None
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
"""
Retrieve a model file from the database by hash and optionally by model type.
Returns the model object or None if the model doesnt exist or an error occurs.
"""
try:
if not can_create_session():
return
dispose_session = False
if session is None:
session = create_session()
dispose_session = True
model = session.query(Model).filter(Model.hash == model_hash)
if model_type is not None:
model = model.filter(Model.type == model_type)
return model.first()
except Exception as e:
logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
return None
finally:
if dispose_session:
session.close()
def retrieve_hash(self, model_path, model_type=None):
"""
Retrieve the hash of a model file from the database.
Returns the hash or None if the model doesnt exist or an error occurs.
"""
try:
if not can_create_session():
return
if model_type is not None:
result = self._validate_path(model_path)
if not result:
return None
model_type, model_relative_path = result
with create_session() as session:
model = self._get_existing_model(
session, model_type, model_relative_path
)
if model and model.hash:
return model.hash
return None
except Exception as e:
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
return None
def _validate_file_extension(self, file_name):
"""Validate that the file extension is supported."""
extension = os.path.splitext(file_name)[1]
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
def _check_existing_file(self, model_type, file_name, expected_hash):
"""Check if file exists and has correct hash."""
destination_path = get_full_path(model_type, file_name, allow_missing=True)
if self._file_exists(destination_path):
model = self.process_file(destination_path)
if model and (expected_hash is None or model.hash == expected_hash):
logging.debug(
f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
)
return destination_path
else:
raise ValueError(
f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
)
return None
def _check_existing_file_by_hash(self, hash, type, url):
"""Check if a file with the given hash exists in the database and on disk."""
hash = hash.lower()
with create_session() as session:
model = self.retrieve_model_by_hash(hash, type, session)
if model:
existing_path = get_full_path(type, model.path)
if existing_path:
logging.debug(
f"File {model.path} already exists in the database at {existing_path}"
)
self._ensure_source_url(session, model, url)
return existing_path
else:
logging.debug(
f"File {model.path} exists in the database but not on disk"
)
return None
def _download_file(self, url, destination_path, hasher):
"""Download a file and update the hasher with its contents."""
response = requests.get(url, stream=True)
logging.info(f"Downloading {url} to {destination_path}")
with open(destination_path, "wb") as f:
total_size = int(response.headers.get("content-length", 0))
if total_size > 0:
pbar = comfy.utils.ProgressBar(total_size)
else:
pbar = None
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
for chunk in response.iter_content(chunk_size=128 * 1024):
if chunk:
f.write(chunk)
hasher.update(chunk)
progress_bar.update(len(chunk))
if pbar:
pbar.update(len(chunk))
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
"""Verify that the downloaded file has the expected hash."""
if expected_hash is not None and calculated_hash != expected_hash:
self._remove_file(destination_path)
raise ValueError(
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
)
def _remove_file(self, file_path):
"""Remove a file from disk."""
os.remove(file_path)
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
"""
Ensure a model file is downloaded and has the correct hash.
Returns the path to the downloaded file.
"""
logging.debug(
f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
)
# Validate file extension
self._validate_file_extension(desired_file_name)
# Check if file exists with correct hash
if hash:
existing_path = self._check_existing_file_by_hash(hash, type, url)
if existing_path:
return existing_path
# Check if file exists locally
destination_path = get_full_path(type, desired_file_name, allow_missing=True)
existing_path = self._check_existing_file(type, desired_file_name, hash)
if existing_path:
return existing_path
# Download the file
hasher = self._get_hasher()
self._download_file(url, destination_path, hasher)
# Verify hash
calculated_hash = hasher.hexdigest()
self._verify_downloaded_hash(calculated_hash, hash, destination_path)
# Update database
self.process_file(destination_path, url, calculated_hash)
# TODO: Notify frontend to reload models
return destination_path
model_processor = ModelProcessor()

View File

@@ -20,13 +20,15 @@ class FileInfo(TypedDict):
path: str
size: int
modified: int
created: int
def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
"modified": os.path.getmtime(path)
"modified": os.path.getmtime(path),
"created": os.path.getctime(path)
}

View File

@@ -210,6 +210,7 @@ database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
if comfy.options.args_parsing:
args = parser.parse_args()

View File

@@ -43,7 +43,6 @@ if TYPE_CHECKING:
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor

View File

@@ -106,10 +106,12 @@ def model_sampling(model_config, model_type):
return ModelSampling(model_config)
def convert_tensor(extra, dtype):
def convert_tensor(extra, dtype, device):
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra = extra.to(dtype=dtype, device=device)
else:
extra = extra.to(device=device)
return extra
@@ -169,20 +171,21 @@ class BaseModel(torch.nn.Module):
dtype = self.manual_cast_dtype
xc = xc.to(dtype)
device = xc.device
t = self.model_sampling.timestep(t).float()
if context is not None:
context = context.to(dtype)
context = context.to(dtype=dtype, device=device)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "dtype"):
extra = convert_tensor(extra, dtype)
extra = convert_tensor(extra, dtype, device)
elif isinstance(extra, list):
ex = []
for ext in extra:
ex.append(convert_tensor(ext, dtype))
ex.append(convert_tensor(ext, dtype, device))
extra = ex
extra_conds[o] = extra

View File

@@ -50,10 +50,16 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def is_html_file(file_path):
with open(file_path, "rb") as f:
content = f.read(100)
return b"<!DOCTYPE html>" in content or b"<html" in content
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
device = torch.device("cpu")
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
@@ -66,6 +72,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if return_metadata:
metadata = f.metadata()
except Exception as e:
if is_html_file(ckpt):
raise ValueError("{}\n\nFile path: {}\n\nThe requested file is an HTML document not a safetensors file. Please re-download the file, not the web page.".format(e, ckpt))
if len(e.args) > 0:
message = e.args[0]
if "HeaderTooLarge" in message:
@@ -93,6 +101,13 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
sd = pl_sd
else:
sd = pl_sd
try:
from app.model_processor import model_processor
model_processor.process_file(ckpt)
except Exception as e:
logging.error(f"Error processing file {ckpt}: {e}")
return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None):

View File

@@ -275,10 +275,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
def get_full_path(folder_name: str, filename: str) -> str | None:
"""
Get the full path of a file in a folder, has to be a file
"""
def get_full_path(folder_name: str, filename: str, allow_missing: bool = False) -> str | None:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths:
@@ -291,6 +288,8 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
return full_path
elif os.path.islink(full_path):
logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))
elif allow_missing:
return full_path
return None
@@ -305,6 +304,27 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
return full_path
def get_relative_path(full_path: str) -> tuple[str, str] | None:
"""Convert a full path back to a type-relative path.
Args:
full_path: The full path to the file
Returns:
tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise
"""
global folder_names_and_paths
full_path = os.path.normpath(full_path)
for model_type, (paths, _) in folder_names_and_paths.items():
for base_path in paths:
base_path = os.path.normpath(base_path)
if full_path.startswith(base_path):
relative_path = os.path.relpath(full_path, base_path)
return model_type, relative_path
return None
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths

View File

@@ -164,7 +164,6 @@ def cuda_malloc_warning():
if cuda_malloc_warning:
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server_instance):
current_time: float = 0.0
cache_type = execution.CacheType.CLASSIC
@@ -279,7 +278,6 @@ def cleanup_temp():
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
def setup_database():
try:
from app.database.db import init_db, dependencies_available
@@ -288,7 +286,6 @@ def setup_database():
except Exception as e:
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
def start_comfyui(asyncio_loop=None):
"""
Starts the ComfyUI server using the provided asyncio event loop or creates a new one.

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.23.4
comfyui-workflow-templates==0.1.45
comfyui-workflow-templates==0.1.47
comfyui-embedded-docs==0.2.4
torch
torchsde
@@ -20,6 +20,7 @@ tqdm
psutil
alembic
SQLAlchemy
blake3
#non essential dependencies:
kornia>=0.7.1

View File

@@ -0,0 +1,253 @@
import pytest
from unittest.mock import patch, MagicMock
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.model_processor import ModelProcessor
from app.database.models import Model, Base
import os
# Test data constants
TEST_MODEL_TYPE = "checkpoints"
TEST_URL = "http://example.com/model.safetensors"
TEST_FILE_NAME = "model.safetensors"
TEST_EXPECTED_HASH = "abc123"
TEST_DESTINATION_PATH = "/path/to/model.safetensors"
def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None):
"""Helper to create a test model in the database."""
model = Model(path=file_name, type=model_type, hash=hash_value, file_size=file_size, source_url=source_url)
session.add(model)
session.commit()
return model
def setup_mock_hash_calculation(model_processor, hash_value):
"""Helper to setup hash calculation mocks."""
mock_hash = MagicMock()
mock_hash.hexdigest.return_value = hash_value
return patch.object(model_processor, "_get_hasher", return_value=mock_hash)
def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None):
"""Helper to verify model exists in database with correct attributes."""
db_model = session.query(Model).filter_by(path=file_name).first()
assert db_model is not None
if expected_hash:
assert db_model.hash == expected_hash
if expected_type:
assert db_model.type == expected_type
return db_model
@pytest.fixture
def db_engine():
# Configure in-memory database
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
yield engine
Base.metadata.drop_all(engine)
@pytest.fixture
def db_session(db_engine):
Session = sessionmaker(bind=db_engine)
session = Session()
yield session
session.close()
@pytest.fixture
def mock_get_relative_path():
with patch("app.model_processor.get_relative_path") as mock:
mock.side_effect = lambda path: (TEST_MODEL_TYPE, os.path.basename(path))
yield mock
@pytest.fixture
def mock_get_full_path():
with patch("app.model_processor.get_full_path") as mock:
mock.return_value = TEST_DESTINATION_PATH
yield mock
@pytest.fixture
def model_processor(db_session, mock_get_relative_path, mock_get_full_path):
with patch("app.model_processor.create_session", return_value=db_session):
with patch("app.model_processor.can_create_session", return_value=True):
processor = ModelProcessor()
# Setup test state
processor.removed_files = []
processor.downloaded_files = []
processor.file_exists = {}
def mock_download_file(url, destination_path, hasher):
processor.downloaded_files.append((url, destination_path))
processor.file_exists[destination_path] = True
# Simulate writing some data to the file
test_data = b"test data"
hasher.update(test_data)
def mock_remove_file(file_path):
processor.removed_files.append(file_path)
if file_path in processor.file_exists:
del processor.file_exists[file_path]
# Setup common patches
file_exists_patch = patch.object(
processor,
"_file_exists",
side_effect=lambda path: processor.file_exists.get(path, False),
)
file_size_patch = patch.object(
processor,
"_get_file_size",
side_effect=lambda path: (
1000 if processor.file_exists.get(path, False) else 0
),
)
download_file_patch = patch.object(
processor, "_download_file", side_effect=mock_download_file
)
remove_file_patch = patch.object(
processor, "_remove_file", side_effect=mock_remove_file
)
with (
file_exists_patch,
file_size_patch,
download_file_patch,
remove_file_patch,
):
yield processor
def test_ensure_downloaded_invalid_extension(model_processor):
# Ensure that an unsupported file extension raises an error to prevent unsafe file downloads
with pytest.raises(ValueError, match="Unsupported unsafe file for download"):
model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, "model.exe")
def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session):
# Ensure that a file with the same hash but from a different source is not downloaded again
SOURCE_URL = "https://example.com/other.sft"
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH, source_url=SOURCE_URL)
model_processor.file_exists[TEST_DESTINATION_PATH] = True
result = model_processor.ensure_downloaded(
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
)
assert result == TEST_DESTINATION_PATH
model = verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
assert model.source_url == SOURCE_URL # Ensure the source URL is not overwritten
def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session):
# Ensure that a file with a different hash raises an error
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash")
model_processor.file_exists[TEST_DESTINATION_PATH] = True
with pytest.raises(ValueError, match="File .* exists with hash .* but expected .*"):
model_processor.ensure_downloaded(
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
)
def test_ensure_downloaded_new_file(model_processor, db_session):
# Ensure that a new file is downloaded
model_processor.file_exists[TEST_DESTINATION_PATH] = False
with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH):
result = model_processor.ensure_downloaded(
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
)
assert result == TEST_DESTINATION_PATH
assert len(model_processor.downloaded_files) == 1
assert model_processor.downloaded_files[0] == (TEST_URL, TEST_DESTINATION_PATH)
assert model_processor.file_exists[TEST_DESTINATION_PATH]
verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
# Ensure that download that results in a different hash raises an error
model_processor.file_exists[TEST_DESTINATION_PATH] = False
with setup_mock_hash_calculation(model_processor, "different_hash"):
with pytest.raises(
ValueError,
match="Downloaded file hash .* does not match expected hash .*",
):
model_processor.ensure_downloaded(
TEST_MODEL_TYPE,
TEST_URL,
TEST_FILE_NAME,
TEST_EXPECTED_HASH,
)
assert len(model_processor.removed_files) == 1
assert model_processor.removed_files[0] == TEST_DESTINATION_PATH
assert TEST_DESTINATION_PATH not in model_processor.file_exists
assert db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() is None
def test_process_file_without_hash(model_processor, db_session):
# Test processing file without provided hash
model_processor.file_exists[TEST_DESTINATION_PATH] = True
with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH):
result = model_processor.process_file(TEST_DESTINATION_PATH)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
def test_retrieve_model_by_hash(model_processor, db_session):
# Test retrieving model by hash
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
def test_retrieve_model_by_hash_and_type(model_processor, db_session):
# Test retrieving model by hash and type
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
assert result.type == TEST_MODEL_TYPE
def test_retrieve_hash(model_processor, db_session):
# Test retrieving hash for existing model
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
with patch.object(
model_processor,
"_validate_path",
return_value=(TEST_MODEL_TYPE, TEST_FILE_NAME),
):
result = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE)
assert result == TEST_EXPECTED_HASH
def test_validate_file_extension_valid_extensions(model_processor):
# Test all valid file extensions
valid_extensions = [".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"]
for ext in valid_extensions:
model_processor._validate_file_extension(f"test{ext}") # Should not raise
def test_process_file_existing_without_source_url(model_processor, db_session):
# Test processing an existing file that needs its source URL updated
model_processor.file_exists[TEST_DESTINATION_PATH] = True
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
assert result.source_url == TEST_URL
db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
assert db_model.source_url == TEST_URL