Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os
- import torch
- from torch.utils.data import DataLoader
- from argparse import ArgumentParser
- from accelerate import Accelerator
- from accelerate.utils import set_seed
- import logging
- from dataset import VideoData
- from synthesis import SynthesisNet
- from diffusion.momo import MoMo
- from utils import set_mode, tensor2opencv
- import cv2
- import glob
- import subprocess
- from tqdm import tqdm
# Configure root logging: timestamped INFO-level messages, e.g.
# "2024-01-01 12:00:00,000 - INFO - message".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def parse_args(argv=None):
    """Parse command-line arguments for x2 video interpolation.

    Args:
        argv (list[str] | None): Argument list to parse. Defaults to None,
            in which case argparse falls back to ``sys.argv[1:]`` — existing
            CLI callers are unaffected; passing an explicit list enables
            programmatic use and testing.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = ArgumentParser()
    parser.add_argument('--input_dir', type=str, required=True,
                        help='Path to the directory containing video files.')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Path to save the interpolated results.')
    parser.add_argument('--ckpt_path', type=str, default='./experiments/diffusion/momo_full/weights/model.pth',
                        help='Path to the pretrained model weights')
    parser.add_argument('--seed', type=int, default=42, help='Random seed setting')
    parser.add_argument('--mp', type=str, default='no', choices=['fp16', 'bf16', 'no'],
                        help='Use mixed precision')
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--inf_steps', type=int, default=8,
                        help='Number of denoising steps to use for inference.')
    parser.add_argument('--resize_to_fit', action='store_true',
                        help='Fit to training resolution and resize back to input resolution for inference.')
    parser.add_argument('--pad_to_fit_unet', action='store_true',
                        help='Avoid errors in resolution mismatch after a sequence of downsamplings and upsamplings in the U-Net by padding vs resizing')
    return parser.parse_args(argv)
def get_average_fps(video_path):
    """
    Gets the average FPS of a video using ffprobe.

    Args:
        video_path (str): Path to the video file.

    Returns:
        float: The average FPS of the video.

    Raises:
        RuntimeError: If ffprobe exits with an error, or if it reports an
            invalid average frame rate (e.g. '0/0' for streams without one).
    """
    cmd = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries', 'stream=avg_frame_rate',
           '-of', 'default=noprint_wrappers=1:nokey=1', video_path]
    # subprocess.run with captured output replaces the manual
    # Popen/communicate pair; the argument-list form avoids shell injection.
    result = subprocess.run(cmd, capture_output=True)
    if result.returncode != 0:
        raise RuntimeError(f'ffprobe error: {result.stderr.decode()}')
    num, den = map(int, result.stdout.decode().strip().split('/'))
    if den == 0:
        # ffprobe reports '0/0' when a stream has no average frame rate;
        # raise RuntimeError (which callers already handle to skip the file)
        # instead of crashing with ZeroDivisionError.
        raise RuntimeError(f'ffprobe reported an invalid frame rate for: {video_path}')
    return num / den
@torch.no_grad()
def interpolate_video(video_path, output_path, args, accelerator, model):
    """
    Doubles the frame rate of one video with the MoMo model and writes the
    result losslessly (FFV1), muxing in the original file's audio track.

    Args:
        video_path (str): Path to the input video file.
        output_path (str): Path to save the interpolated video file.
        args (argparse.Namespace): Parsed command-line arguments.
        accelerator (accelerate.Accelerator): Accelerator used for
            (optionally multi-process) inference.
        model (torch.nn.Module): The MoMo model.
    """
    video_data = VideoData(video_path)
    dataloader = DataLoader(video_data, batch_size=1, shuffle=False, num_workers=args.num_workers)
    # NOTE(review): the model was already passed through accelerator.prepare()
    # in main(); preparing it again for every video looks redundant — confirm
    # this double-prepare is harmless.
    model, dataloader = accelerator.prepare(model, dataloader)
    original_fps = get_average_fps(video_path)
    interpolated_fps = original_fps * 2  # x2 interpolation doubles the FPS
    set_seed(args.seed, device_specific=True)  # reproducible denoising per process
    set_mode(model, mode='eval')
    output_frames = []
    total_frames = len(dataloader)
    for i, data in enumerate(tqdm(dataloader, desc="Processing frames", total=total_frames)):
        # data presumably holds a pair of consecutive frames — verify against
        # VideoData's __getitem__.
        frame0, frame1 = data
        # Predict the middle frame between frame0 and frame1 via the diffusion model.
        pred, _ = model(torch.stack([frame0, frame1], dim=2), num_inference_steps=args.inf_steps,
                        resize_to_fit=args.resize_to_fit, pad_to_fit_unet=args.pad_to_fit_unet)
        # Gather results from all processes so rank 0 can assemble the full video.
        pred = accelerator.gather_for_metrics(pred.contiguous())
        frame0 = accelerator.gather_for_metrics(frame0)
        if accelerator.is_main_process:
            output_frames.append(tensor2opencv(frame0[0].cpu()))
            output_frames.append(tensor2opencv(pred[0].cpu()))
            # NOTE(review): frame1 is appended on EVERY iteration; if VideoData
            # yields overlapping pairs (f0,f1), (f1,f2), ... this duplicates each
            # interior frame in the output — confirm the pairing scheme.
            output_frames.append(tensor2opencv(frame1[-1].cpu()))
    if accelerator.is_main_process:
        # Dump frames as lossless BMPs, then let ffmpeg encode and mux audio.
        temp_image_folder = os.path.join(args.output_dir, "temp_frames")
        os.makedirs(temp_image_folder, exist_ok=True)
        for idx, frame in enumerate(output_frames):
            cv2.imwrite(os.path.join(temp_image_folder, f"frame_{idx:04d}.bmp"), frame)
        cmd = [
            'ffmpeg',
            '-y',                                   # overwrite output without prompting
            '-framerate', str(interpolated_fps),
            '-i', os.path.join(temp_image_folder, 'frame_%04d.bmp'),
            '-i', video_path,                       # second input used only for its audio
            '-map', '0:v',                          # video: the interpolated frame sequence
            # NOTE(review): '-map 1:a' makes ffmpeg fail when the source video
            # has no audio stream — consider '-map 1:a?'. Confirm intent.
            '-map', '1:a',
            '-c:v', 'ffv1',                         # lossless FFV1 codec
            '-pix_fmt', 'rgb48',
            '-metadata:s:v:0', 'encoder=FFV1',
            '-level', '3',
            '-g', '1',                              # every frame is a keyframe
            '-slices', '24',
            '-slicecrc', '1',                       # per-slice CRCs for integrity checking
            output_path
        ]
        # NOTE(review): return code is not checked; an ffmpeg failure passes silently.
        subprocess.run(cmd)
        # Remove the temporary frame dump.
        for file in os.listdir(temp_image_folder):
            os.remove(os.path.join(temp_image_folder, file))
        os.rmdir(temp_image_folder)
def main():
    """Run x2 frame interpolation over every readable video in the input directory."""
    args = parse_args()
    accelerator = Accelerator(mixed_precision=args.mp, split_batches=False)
    accelerator.print('\n\n#######################################################################################\n')
    accelerator.print(f'x2 interpolation on videos in <{args.input_dir}>\n')
    accelerator.print(args)
    accelerator.print('\n#######################################################################################\n\n')
    synth_model = SynthesisNet()
    model = MoMo(synth_model=synth_model)
    # Explicit exception instead of assert: asserts are stripped under `python -O`,
    # which would let a bad path fall through to a confusing torch.load error.
    if not os.path.exists(args.ckpt_path):
        raise FileNotFoundError(f'Path to model checkpoint does not exist: {args.ckpt_path}')
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (consider weights_only=True if the
    # checkpoint format allows it).
    ckpt = torch.load(args.ckpt_path, map_location='cpu')
    param_ckpt = ckpt['model']
    model.load_state_dict(param_ckpt)
    del ckpt  # release the checkpoint dict before inference
    model = accelerator.prepare(model)
    # Make sure the destination exists before ffmpeg tries to write into it.
    os.makedirs(args.output_dir, exist_ok=True)
    # Find all files in the input directory
    for file in glob.glob(os.path.join(args.input_dir, '*')):
        # Skip anything ffprobe cannot read as a video (RuntimeError from get_average_fps).
        try:
            get_average_fps(file)
        except RuntimeError:
            continue
        video_name = os.path.splitext(os.path.basename(file))[0]
        output_path = os.path.join(args.output_dir, f"{video_name}_interpolated_Momo.mov")
        accelerator.print(f'Processing: {file} -> {output_path}')
        interpolate_video(file, output_path, args, accelerator, model)
    accelerator.wait_for_everyone()
    accelerator.print('Batch interpolation finished.')


if __name__ == '__main__':
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement