diff --git a/comfy_api/input/video_types.py b/comfy_api/input/video_types.py index bb936e0a4..5d95dc507 100644 --- a/comfy_api/input/video_types.py +++ b/comfy_api/input/video_types.py @@ -2,6 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import Optional, Union import io +import av from comfy_api.util import VideoContainer, VideoCodec, VideoComponents class VideoInput(ABC): @@ -70,3 +71,15 @@ class VideoInput(ABC): components = self.get_components() frame_count = components.images.shape[0] return float(frame_count / components.frame_rate) + + def get_container_format(self) -> str: + """ + Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). + + Returns: + Container format as string + """ + # Default implementation - subclasses should override for better performance + source = self.get_stream_source() + with av.open(source, mode="r") as container: + return container.format.name diff --git a/comfy_api/input_impl/video_types.py b/comfy_api/input_impl/video_types.py index 9ae818f4e..91e7c1bfa 100644 --- a/comfy_api/input_impl/video_types.py +++ b/comfy_api/input_impl/video_types.py @@ -121,6 +121,18 @@ class VideoFromFile(VideoInput): raise ValueError(f"Could not determine duration for file '{self.__file}'") + def get_container_format(self) -> str: + """ + Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). + + Returns: + Container format as string + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + with av.open(self.__file, mode='r') as container: + return container.format.name + def get_components_internal(self, container: InputContainer) -> VideoComponents: # Get video frames frames = [] diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 057021efa..789fcef02 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -5,7 +5,6 @@ import torch from comfy_api_nodes.util.validation_utils import ( get_image_dimensions, validate_image_dimensions, - validate_video_dimensions, ) @@ -176,54 +175,76 @@ def validate_input_image( ) -def validate_input_video( - video: VideoInput, num_frames_out: int, with_frame_conditioning: bool = False -): +def validate_video_to_video_input(video: VideoInput) -> VideoInput: + """ + Validates and processes video input for Moonvalley Video-to-Video generation. + + Args: + video: Input video to validate + + Returns: + Validated and potentially trimmed video + + Raises: + ValueError: If video doesn't meet requirements + MoonvalleyApiError: If video duration is too short + """ + width, height = _get_video_dimensions(video) + _validate_video_dimensions(width, height) + _validate_container_format(video) + + return _validate_and_trim_duration(video) + + +def _get_video_dimensions(video: VideoInput) -> tuple[int, int]: + """Extracts video dimensions with error handling.""" try: - width, height = video.get_dimensions() + return video.get_dimensions() except Exception as e: logging.error("Error getting dimensions of video: %s", e) raise ValueError(f"Cannot get video dimensions: {e}") from e - validate_input_media(width, height, with_frame_conditioning) - validate_video_dimensions( - video, - min_width=MIN_VID_WIDTH, - min_height=MIN_VID_HEIGHT, - max_width=MAX_VID_WIDTH, - max_height=MAX_VID_HEIGHT, - ) - trimmed_video = validate_input_video_length(video, num_frames_out) - return trimmed_video +def _validate_video_dimensions(width: int, height: int) -> None: + """Validates video dimensions meet Moonvalley V2V requirements.""" + supported_resolutions = { + (1920, 1080), (1080, 1920), (1152, 1152), + (1536, 1152), (1152, 1536) + } + + if (width, height) not in supported_resolutions: + supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)]) + raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") -def validate_input_video_length(video: VideoInput, num_frames: int): +def _validate_container_format(video: VideoInput) -> None: + """Validates video container format is MP4.""" + container_format = video.get_container_format() + if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']: + raise ValueError(f"Only MP4 container format supported. Got: {container_format}") - if video.get_duration() > 60: - raise MoonvalleyApiError( - "Input Video lenth should be less than 1min. Please trim." - ) - if num_frames == 128: - if video.get_duration() < 5: - raise MoonvalleyApiError( - "Input Video length is less than 5s. Please use a video longer than or equal to 5s." - ) - if video.get_duration() > 5: - # trim video to 5s - video = trim_video(video, 5) - if num_frames == 256: - if video.get_duration() < 10: - raise MoonvalleyApiError( - "Input Video length is less than 10s. Please use a video longer than or equal to 10s." - ) - if video.get_duration() > 10: - # trim video to 10s - video = trim_video(video, 10) +def _validate_and_trim_duration(video: VideoInput) -> VideoInput: + """Validates video duration and trims to 5 seconds if needed.""" + duration = video.get_duration() + _validate_minimum_duration(duration) + return _trim_if_too_long(video, duration) + + +def _validate_minimum_duration(duration: float) -> None: + """Ensures video is at least 5 seconds long.""" + if duration < 5: + raise MoonvalleyApiError("Input video must be at least 5 seconds long.") + + +def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: + """Trims video to 5 seconds if longer.""" + if duration > 5: + return trim_video(video, 5) return video + def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: """ Returns a new VideoInput object trimmed from the beginning to the specified duration, @@ -278,15 +299,13 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels" ) - # Calculate target frame count that's divisible by 32 + # Calculate target frame count that's divisible by 16 fps = input_container.streams.video[0].average_rate estimated_frames = int(duration_sec * fps) - target_frames = ( - estimated_frames // 32 - ) * 32 # Round down to nearest multiple of 32 + target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16 if target_frames == 0: - raise ValueError("Video too short: need at least 32 frames for Moonvalley") + raise ValueError("Video too short: need at least 16 frames for Moonvalley") frame_count = 0 audio_frame_count = 0 @@ -353,8 +372,8 @@ class BaseMoonvalleyVideoNode: "16:9 (1920 x 1080)": {"width": 1920, "height": 1080}, "9:16 (1080 x 1920)": {"width": 1080, "height": 1920}, "1:1 (1152 x 1152)": {"width": 1152, "height": 1152}, - "4:3 (1440 x 1080)": {"width": 1440, "height": 1080}, - "3:4 (1080 x 1440)": {"width": 1080, "height": 1440}, + "4:3 (1536 x 1152)": {"width": 1536, "height": 1152}, + "3:4 (1152 x 1536)": {"width": 1152, "height": 1536}, "21:9 (2560 x 1080)": {"width": 2560, "height": 1080}, } if resolution in res_map: @@ -494,7 +513,6 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): image = kwargs.get("image", None) if image is None: raise MoonvalleyApiError("image is required") - total_frames = get_total_frames_from_length() validate_input_image(image, True) validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) @@ -505,7 +523,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): steps=kwargs.get("steps"), seed=kwargs.get("seed"), guidance_scale=kwargs.get("prompt_adherence"), - num_frames=total_frames, + num_frames=128, width=width_height.get("width"), height=width_height.get("height"), use_negative_prompts=True, @@ -549,39 +567,45 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): @classmethod def INPUT_TYPES(cls): - input_types = super().INPUT_TYPES() - for param in ["resolution", "image"]: - if param in input_types["required"]: - del input_types["required"][param] - if param in input_types["optional"]: - del input_types["optional"][param] - input_types["optional"] = { - "video": ( - IO.VIDEO, - { - "default": "", - "multiline": False, - "tooltip": "The reference video used to generate the output video. Input a 5s video for 128 frames and a 10s video for 256 frames. Longer videos will be trimmed automatically.", - }, - ), - "control_type": ( - ["Motion Transfer", "Pose Transfer"], - {"default": "Motion Transfer"}, - ), - "motion_intensity": ( - "INT", - { - "default": 100, - "step": 1, - "min": 0, - "max": 100, - "tooltip": "Only used if control_type is 'Motion Transfer'", - }, - ), + return { + "required": { + "prompt": model_field_to_node_input( + IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text", + multiline=True + ), + "negative_prompt": model_field_to_node_input( + IO.STRING, + MoonvalleyVideoToVideoInferenceParams, + "negative_prompt", + multiline=True, + default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts" + ), + "seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + "optional": { + "video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}), + "control_type": ( + ["Motion Transfer", "Pose Transfer"], + {"default": "Motion Transfer"}, + ), + "motion_intensity": ( + "INT", + { + "default": 100, + "step": 1, + "min": 0, + "max": 100, + "tooltip": "Only used if control_type is 'Motion Transfer'", + }, + ) + } } - return input_types - RETURN_TYPES = ("VIDEO",) RETURN_NAMES = ("video",) @@ -589,15 +613,13 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): video = kwargs.get("video") - num_frames = get_total_frames_from_length() if not video: raise MoonvalleyApiError("video is required") - """Validate video input""" video_url = "" if video: - validated_video = validate_input_video(video, num_frames, False) + validated_video = validate_video_to_video_input(video) video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs) control_type = kwargs.get("control_type") @@ -605,12 +627,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): """Validate prompts and inference input""" validate_prompts(prompt, negative_prompt) - inference_params = MoonvalleyVideoToVideoInferenceParams( + + # Only include motion_intensity for Motion Transfer + control_params = {} + if control_type == "Motion Transfer" and motion_intensity is not None: + control_params['motion_intensity'] = motion_intensity + + inference_params=MoonvalleyVideoToVideoInferenceParams( negative_prompt=negative_prompt, - steps=kwargs.get("steps"), seed=kwargs.get("seed"), - guidance_scale=kwargs.get("prompt_adherence"), - control_params={"motion_intensity": motion_intensity}, + control_params=control_params ) control = self.parseControlParameter(control_type) @@ -667,17 +693,16 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): ): validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = self.parseWidthHeightFromRes(kwargs.get("resolution")) - num_frames = get_total_frames_from_length() - inference_params = MoonvalleyTextToVideoInferenceParams( - negative_prompt=negative_prompt, - steps=kwargs.get("steps"), - seed=kwargs.get("seed"), - guidance_scale=kwargs.get("prompt_adherence"), - num_frames=num_frames, - width=width_height.get("width"), - height=width_height.get("height"), - ) + inference_params=MoonvalleyTextToVideoInferenceParams( + negative_prompt=negative_prompt, + steps=kwargs.get("steps"), + seed=kwargs.get("seed"), + guidance_scale=kwargs.get("prompt_adherence"), + num_frames=128, + width=width_height.get("width"), + height=width_height.get("height"), + ) request = MoonvalleyTextToVideoRequest( prompt_text=prompt, inference_params=inference_params ) @@ -707,22 +732,12 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): NODE_CLASS_MAPPINGS = { "MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode, "MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode, - # "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode, + "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode, } NODE_DISPLAY_NAME_MAPPINGS = { "MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video", "MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video", - # "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video", + "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video", } - - -def get_total_frames_from_length(length="5s"): - # if length == '5s': - # return 128 - # elif length == '10s': - # return 256 - return 128 - # else: - # raise MoonvalleyApiError("length is required")