Advertisement
zelenooki87

Batch script for MoMo diffusion frame interpolation 2x

Jun 27th, 2024 (edited)
98
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.63 KB | Science | 0 0
  1. import os
  2. import torch
  3. from torch.utils.data import DataLoader
  4. from argparse import ArgumentParser
  5. from accelerate import Accelerator
  6. from accelerate.utils import set_seed
  7. import logging
  8.  
  9. from dataset import VideoData
  10. from synthesis import SynthesisNet
  11. from diffusion.momo import MoMo
  12. from utils import set_mode, tensor2opencv
  13. import cv2
  14. import glob
  15. import subprocess
  16. from tqdm import tqdm
  17.  
# Configure logging
# Root-logger setup: timestamped INFO-level messages for the whole script.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  20.  
  21. def parse_args():
  22.     parser = ArgumentParser()
  23.     parser.add_argument('--input_dir', type=str, required=True,
  24.                         help='Path to the directory containing video files.')
  25.     parser.add_argument('--output_dir', type=str, required=True,
  26.                         help='Path to save the interpolated results.')
  27.     parser.add_argument('--ckpt_path', type=str, default='./experiments/diffusion/momo_full/weights/model.pth',
  28.                         help='Path to the pretrained model weights')
  29.     parser.add_argument('--seed', type=int, default=42, help='Random seed setting')
  30.     parser.add_argument('--mp', type=str, default='no', choices=['fp16', 'bf16', 'no'],
  31.                         help='Use mixed precision')
  32.     parser.add_argument('--num_workers', type=int, default=2)
  33.     parser.add_argument('--inf_steps', type=int, default=8,
  34.                         help='Number of denoising steps to use for inference.')
  35.     parser.add_argument('--resize_to_fit', action='store_true',
  36.                         help='Fit to training resolution and resize back to input resolution for inference.')
  37.     parser.add_argument('--pad_to_fit_unet', action='store_true',
  38.                         help='Avoid errors in resolution mismatch after a sequence of downsamplings and upsamplings in the U-Net by padding vs resizing')
  39.     return parser.parse_args()
  40.  
  41. def get_average_fps(video_path):
  42.     """
  43.    Gets the average FPS of a video using ffprobe.
  44.  
  45.    Args:
  46.        video_path (str): Path to the video file.
  47.  
  48.    Returns:
  49.        float: The average FPS of the video.
  50.    """
  51.     cmd = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries', 'stream=avg_frame_rate',
  52.            '-of', 'default=noprint_wrappers=1:nokey=1', video_path]
  53.     process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  54.     output, error = process.communicate()
  55.     if process.returncode != 0:
  56.         raise RuntimeError(f'ffprobe error: {error.decode()}')
  57.     num, den = map(int, output.decode().strip().split('/'))
  58.     return num / den
  59.  
  60. @torch.no_grad()
  61. def interpolate_video(video_path, output_path, args, accelerator, model):
  62.     """
  63.    Interpolates a video using the MoMo model.
  64.  
  65.    Args:
  66.        video_path (str): Path to the input video file.
  67.        output_path (str): Path to save the interpolated video file.
  68.        args (argparse.Namespace): Command-line arguments.
  69.        accelerator (accelerate.Accelerator): Accelerator object for distributed training.
  70.        model (torch.nn.Module): The MoMo model.
  71.    """
  72.     video_data = VideoData(video_path)
  73.     dataloader = DataLoader(video_data, batch_size=1, shuffle=False, num_workers=args.num_workers)
  74.     model, dataloader = accelerator.prepare(model, dataloader)
  75.  
  76.     original_fps = get_average_fps(video_path)
  77.     interpolated_fps = original_fps * 2
  78.  
  79.     set_seed(args.seed, device_specific=True)
  80.     set_mode(model, mode='eval')
  81.     output_frames = []
  82.     total_frames = len(dataloader)
  83.  
  84.     for i, data in enumerate(tqdm(dataloader, desc="Processing frames", total=total_frames)):
  85.         frame0, frame1 = data
  86.         pred, _ = model(torch.stack([frame0, frame1], dim=2), num_inference_steps=args.inf_steps,
  87.                        resize_to_fit=args.resize_to_fit, pad_to_fit_unet=args.pad_to_fit_unet)
  88.         pred = accelerator.gather_for_metrics(pred.contiguous())
  89.         frame0 = accelerator.gather_for_metrics(frame0)
  90.  
  91.         if accelerator.is_main_process:
  92.             output_frames.append(tensor2opencv(frame0[0].cpu()))
  93.             output_frames.append(tensor2opencv(pred[0].cpu()))
  94.  
  95.     output_frames.append(tensor2opencv(frame1[-1].cpu()))
  96.  
  97.     if accelerator.is_main_process:
  98.         temp_image_folder = os.path.join(args.output_dir, "temp_frames")
  99.         os.makedirs(temp_image_folder, exist_ok=True)
  100.  
  101.         for idx, frame in enumerate(output_frames):
  102.             cv2.imwrite(os.path.join(temp_image_folder, f"frame_{idx:04d}.bmp"), frame)
  103.  
  104.         cmd = [
  105.             'ffmpeg',
  106.             '-y',
  107.             '-framerate', str(interpolated_fps),
  108.             '-i', os.path.join(temp_image_folder, 'frame_%04d.bmp'),
  109.             '-i', video_path,
  110.             '-map', '0:v',
  111.             '-map', '1:a',
  112.             '-c:v', 'ffv1',
  113.             '-pix_fmt', 'rgb48',
  114.             '-metadata:s:v:0', 'encoder=FFV1',
  115.             '-level', '3',
  116.             '-g', '1',
  117.             '-slices', '24',
  118.             '-slicecrc', '1',
  119.             output_path
  120.         ]
  121.         subprocess.run(cmd)
  122.  
  123.         for file in os.listdir(temp_image_folder):
  124.             os.remove(os.path.join(temp_image_folder, file))
  125.         os.rmdir(temp_image_folder)
  126.  
  127. def main():
  128.     args = parse_args()
  129.     accelerator = Accelerator(mixed_precision=args.mp, split_batches=False)
  130.  
  131.     accelerator.print('\n\n#######################################################################################\n')
  132.     accelerator.print(f'x2 interpolation on videos in <{args.input_dir}>\n')
  133.     accelerator.print(args)
  134.     accelerator.print('\n#######################################################################################\n\n')
  135.  
  136.     synth_model = SynthesisNet()
  137.     model = MoMo(synth_model=synth_model)
  138.     assert os.path.exists(args.ckpt_path), 'Path to model checkpoints do not exist!'
  139.     ckpt = torch.load(args.ckpt_path, map_location='cpu')
  140.     param_ckpt = ckpt['model']
  141.     model.load_state_dict(param_ckpt)
  142.     del ckpt
  143.  
  144.     model = accelerator.prepare(model)
  145.  
  146.     # Find all files in the input directory
  147.     for file in glob.glob(os.path.join(args.input_dir, '*')):
  148.         # Check if the file is a video file that ffmpeg can read
  149.         try:
  150.             get_average_fps(file)
  151.         except RuntimeError:
  152.             continue
  153.  
  154.         video_name = os.path.splitext(os.path.basename(file))[0]
  155.         output_path = os.path.join(args.output_dir, f"{video_name}_interpolated_Momo.mov")
  156.  
  157.         accelerator.print(f'Processing: {file} -> {output_path}')
  158.         interpolate_video(file, output_path, args, accelerator, model)
  159.  
  160.     accelerator.wait_for_everyone()
  161.     accelerator.print('Batch interpolation finished.')
  162.  
  163. if __name__ == '__main__':
  164.     main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement