Advertisement
zelenooki87

OSEDiff video upscaling script

Aug 9th, 2024 (edited)
238
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.70 KB | Science | 0 0
  1. import os
  2. import sys
  3. sys.path.append(os.getcwd())
  4. import glob
  5. import argparse
  6. import torch
  7. from torchvision import transforms
  8. import torchvision.transforms.functional as F
  9. import numpy as np
  10. from PIL import Image
  11. import subprocess
  12. import time
  13.  
  14. from osediff import OSEDiff_test
  15. from my_utils.wavelet_color_fix import adain_color_fix, wavelet_color_fix
  16.  
  17. from ram.models.ram_lora import ram
  18. from ram import inference_ram as inference
  19.  
  20. tensor_transforms = transforms.Compose([
  21.                 transforms.ToTensor(),
  22.             ])
  23.  
  24. ram_transforms = transforms.Compose([
  25.             transforms.Resize((384, 384)),
  26.             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  27.         ])
  28.  
  29.  
  30. def get_validation_prompt(args, image, model, device='cuda'):
  31.     validation_prompt = ""
  32.     lq = tensor_transforms(image).unsqueeze(0).to(device)
  33.     lq = ram_transforms(lq).to(dtype=weight_dtype)
  34.     captions = inference(lq, model)
  35.     validation_prompt = f"{captions[0]}, {args.prompt},"
  36.    
  37.     return validation_prompt
  38.  
  39. def get_average_fps(video_path):
  40.     """
  41.    Gets the average FPS of a video using ffprobe.
  42.  
  43.    Args:
  44.        video_path (str): Path to the video file.
  45.  
  46.    Returns:
  47.        float: The average FPS of the video.
  48.    """
  49.     cmd = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries', 'stream=avg_frame_rate',
  50.            '-of', 'default=noprint_wrappers=1:nokey=1', video_path]
  51.     process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
  52.     output, error = process.communicate()
  53.     if process.returncode != 0:
  54.         raise RuntimeError(f'ffprobe error: {error}')
  55.     num, den = map(int, output.strip().split('/'))
  56.     return num / den
  57.  
  58. def process_video(args, model, DAPE, weight_dtype, input_video_path, processing_times):
  59.     # Extract frames and get input framerate
  60.     temp_dir = os.path.join(args.output_dir, "temp_frames", os.path.splitext(os.path.basename(input_video_path))[0])
  61.     os.makedirs(temp_dir, exist_ok=True)
  62.  
  63.     # Use ffprobe to get average framerate
  64.     avg_framerate = get_average_fps(input_video_path)
  65.  
  66.     # Extract frames in BMP format
  67.     subprocess.run([
  68.         'ffmpeg', '-i', input_video_path, '-vf', f'fps={avg_framerate}', f'{temp_dir}/frame%04d.bmp'
  69.     ])
  70.  
  71.     # Process each frame
  72.     image_names = sorted(glob.glob(f'{temp_dir}/*.bmp'))
  73.     start_time = time.time()
  74.     for image_name in image_names:
  75.         input_image = Image.open(image_name).convert('RGB')
  76.         ori_width, ori_height = input_image.size
  77.         rscale = args.upscale
  78.         resize_flag = False
  79.         if ori_width < args.process_size//rscale or ori_height < args.process_size//rscale:
  80.             scale = (args.process_size//rscale)/min(ori_width, ori_height)
  81.             input_image = input_image.resize((int(scale*ori_width), int(scale*ori_height)))
  82.             resize_flag = True
  83.         input_image = input_image.resize((input_image.size[0]*rscale, input_image.size[1]*rscale))
  84.  
  85.         new_width = input_image.width - input_image.width % 8
  86.         new_height = input_image.height - input_image.height % 8
  87.         input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
  88.         bname = os.path.basename(image_name)
  89.  
  90.         # get caption
  91.         validation_prompt = get_validation_prompt(args, input_image, DAPE)
  92.         print(f"process {image_name}, tag: {validation_prompt}".encode('utf-8'))
  93.  
  94.         # translate the image
  95.         with torch.no_grad():
  96.             lq = F.to_tensor(input_image).unsqueeze(0).cuda()*2-1
  97.             output_image = model(lq, prompt=validation_prompt)
  98.             output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
  99.             if args.align_method == 'adain':
  100.                 output_pil = adain_color_fix(target=output_pil, source=input_image)
  101.             elif args.align_method == 'wavelet':
  102.                 output_pil = wavelet_color_fix(target=output_pil, source=input_image)
  103.             else:
  104.                 pass
  105.             if resize_flag:
  106.                 output_pil.resize((int(args.upscale*ori_width), int(args.upscale*ori_height)))
  107.  
  108.         output_pil.save(os.path.join(temp_dir, bname))
  109.  
  110.     # Combine frames back into video using the detected average framerate
  111.     output_video_path = os.path.join(args.output_dir, os.path.splitext(os.path.basename(input_video_path))[0] + "_upscaled.mov")
  112.     subprocess.run([
  113.         'ffmpeg',
  114.         '-y', # Overwrite output files without asking
  115.         '-framerate', str(avg_framerate),
  116.         '-i', f'{temp_dir}/frame%04d.bmp',
  117.         '-i', input_video_path,
  118.         '-map', '0:v',
  119.         '-map', '1:a',
  120.         '-c:v', 'ffv1',
  121.         '-pix_fmt', 'rgb48',
  122.         '-metadata:s:v:0', 'encoder=FFV1',
  123.         '-level', '3',
  124.         '-g', '1',
  125.         '-slices', '24',
  126.         '-slicecrc', '1',
  127.         output_video_path
  128.     ])
  129.  
  130.     # Clean up temporary frames
  131.     subprocess.run(['rm', '-r', temp_dir])
  132.  
  133.     end_time = time.time()
  134.     processing_time = end_time - start_time
  135.     processing_times.append((input_video_path, processing_time))
  136.  
  137. if __name__ == "__main__":
  138.     parser = argparse.ArgumentParser()
  139.     parser.add_argument('--input_image', '-i', type=str, default='input', help='path to the input image or video')
  140.     parser.add_argument('--output_dir', '-o', type=str, default='output', help='the directory to save the output')
  141.     parser.add_argument('--pretrained_model_name_or_path', type=str, default='preset/models/stable-diffusion-2-1-base', help='sd model path')
  142.     parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
  143.     parser.add_argument("--process_size", type=int, default=512)
  144.     parser.add_argument("--upscale", type=int, default=4)
  145.     parser.add_argument("--align_method", type=str, choices=['wavelet', 'adain', 'nofix'], default='adain')
  146.     parser.add_argument("--osediff_path", type=str, default='preset/models/osediff.pkl')
  147.     parser.add_argument('--prompt', type=str, default='', help='user prompts')
  148.     parser.add_argument('--ram_path', type=str, default='preset/models/ram_swin_large_14m.pth')
  149.     parser.add_argument('--ram_ft_path', type=str, default='preset/models/DAPE.pth')
  150.     parser.add_argument('--save_prompts', type=bool, default=True)
  151.     # precision setting
  152.     parser.add_argument("--mixed_precision", type=str, choices=['fp16', 'fp32'], default="fp16")
  153.     # merge lora
  154.     parser.add_argument("--merge_and_unload_lora", default=False) # merge lora weights before inference
  155.     # tile setting
  156.     parser.add_argument("--vae_decoder_tiled_size", type=int, default=224)
  157.     parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024)
  158.     parser.add_argument("--latent_tiled_size", type=int, default=96)
  159.     parser.add_argument("--latent_tiled_overlap", type=int, default=32)
  160.  
  161.     args = parser.parse_args()
  162.  
  163.     # initialize the model
  164.     model = OSEDiff_test(args)
  165.  
  166.     # get ram model
  167.     DAPE = ram(pretrained=args.ram_path,
  168.             pretrained_condition=args.ram_ft_path,
  169.             image_size=384,
  170.             vit='swin_l')
  171.     DAPE.eval()
  172.     DAPE.to("cuda")
  173.  
  174.     # weight type
  175.     weight_dtype = torch.float32
  176.     if args.mixed_precision == "fp16":
  177.         weight_dtype = torch.float16
  178.  
  179.     # set weight type
  180.     DAPE = DAPE.to(dtype=weight_dtype)
  181.    
  182.     # make the output dir
  183.     os.makedirs(args.output_dir, exist_ok=True)
  184.  
  185.     # Find all video files in the input directory using ffmpeg
  186.     input_video_paths = []
  187.     for file in os.listdir(args.input_image):
  188.         file_path = os.path.join(args.input_image, file)
  189.         try:
  190.             # Use ffprobe to check if the file is a valid video
  191.             probe_cmd = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries', 'stream=codec_type',
  192.                        '-of', 'default=noprint_wrappers=1:nokey=1', file_path]
  193.             process = subprocess.Popen(probe_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
  194.             output, error = process.communicate()
  195.             if process.returncode == 0 and output.strip() == 'video':
  196.                 input_video_paths.append(file_path)
  197.         except Exception as e:
  198.             print(f"Error checking file {file_path}: {e}")
  199.  
  200.     # Process each video file and store processing times
  201.     processing_times = []
  202.     for input_video_path in input_video_paths:
  203.         print(f"Processing video: {input_video_path}")
  204.         process_video(args, model, DAPE, weight_dtype, input_video_path, processing_times)
  205.  
  206.     # Print processing times after all videos are processed
  207.     print("\nProcessing times:")
  208.     for input_video_path, processing_time in processing_times:
  209.         print(f"{input_video_path}: {processing_time:.2f} seconds")
  210.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement