zelenooki87

test_DLoRAL.py

Jul 8th, 2025
Python | 14.71 KB | Science
import os
import argparse
import time

import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import torchvision.transforms.functional as F
import sys

sys.path.append(os.getcwd())
from src.DLoRAL_model import Generator_eval
from src.my_utils.wavelet_color_fix import adain_color_fix, wavelet_color_fix
import PIL.Image
import math
PIL.Image.MAX_IMAGE_PIXELS = 933120000

import glob
import gc
import cv2
from ram.models.ram_lora import ram
from ram import inference_ram as inference

tensor_transforms = transforms.Compose([
    transforms.ToTensor(),
])

ram_transforms = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

center_crop = transforms.CenterCrop(128)
center_crop_gt = transforms.CenterCrop(512)

def get_validation_prompt(args, image, model, device='cuda'):
    validation_prompt = ""
    lq = tensor_transforms(image).unsqueeze(0).to(device)
    # weight_dtype is a module-level global assigned in the __main__ block below
    lq = ram_transforms(lq).to(dtype=weight_dtype)
    captions = inference(lq, model)
    validation_prompt = f"{captions[0]}, {args.prompt},"

    return validation_prompt
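
# Illustrative only (hypothetical caption): if the RAM/DAPE model returned the
# caption "street, car, night" and --prompt were "best quality",
# get_validation_prompt would return the string "street, car, night, best quality,".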


def extract_frames(video_path):
    video_capture = cv2.VideoCapture(video_path)

    frame_number = 0
    success, frame = video_capture.read()
    frame_images = []

    # Loop through frames
    while success:
        # Save each frame as an image
        frame_dir = '{}'.format(video_path.split('.mp4')[0])
        if not os.path.exists(frame_dir):
            os.makedirs(frame_dir)
        frame_filename = "frame_{:04d}.png".format(frame_number)
        cv2.imwrite('{}/{}'.format(frame_dir, frame_filename), frame)
        print("Writing frame to {}/{}".format(frame_dir, frame_filename))

        frame_images.append(os.path.join(frame_dir, frame_filename))

        # Move to the next frame
        success, frame = video_capture.read()
        frame_number += 1

    video_capture.release()
    print(f"Frames extracted from {video_path} successfully!")

    return frame_images


def process_video_directory(input_directory):
    video_files = glob.glob(os.path.join(input_directory, "*.mp4"))
    all_video_data = []

    # Process each video and extract frames
    for video_file in video_files:
        print(f"Processing video: {video_file}")

        # Extract frames and get their names
        frame_images = extract_frames(video_file)

        # Extract the video name (without the .mp4 extension) for consistent naming
        video_name = os.path.basename(video_file).split('.')[0]

        all_video_data.append((video_name, frame_images))

    return all_video_data

def compute_frame_difference_mask(frames):
    # Per-pixel variance across the stacked frames; pixels whose variance is at or
    # above the mean variance are marked 1 (changed), all others are 0.
    ambi_matrix = frames.var(dim=0)
    threshold = ambi_matrix.mean().item()
    mask_id = torch.where(ambi_matrix >= threshold, ambi_matrix, torch.zeros_like(ambi_matrix))
    frame_mask = torch.where(mask_id == 0, mask_id, torch.ones_like(mask_id))
    return frame_mask
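
# Illustrative example (made-up values): stacking two single-channel frames
#   frames = torch.stack([torch.tensor([[0., 1.]]), torch.tensor([[0., 3.]])])
# gives a per-pixel variance of [[0., 2.]] with mean 1.0, so
# compute_frame_difference_mask returns [[0., 1.]] -- only the pixel that
# changed between the two frames is flagged.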

def pil_center_crop(image, target_size):
    """
    Perform center cropping on a PIL Image.
    Args:
        image: PIL Image object
        target_size: Target dimensions (width, height)
    """
    width, height = image.size
    target_width, target_height = target_size

    # Calculate the top-left coordinates
    left = (width - target_width) // 2
    upper = (height - target_height) // 2

    # Calculate the bottom-right coordinates
    right = left + target_width
    lower = upper + target_height

    # Perform cropping
    return image.crop((left, upper, right, lower))
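
# Illustrative only: center-cropping a 1920x1080 image to target_size=(512, 512)
# with the helper above uses the crop box (left=704, upper=284, right=1216, lower=796).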

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_image', '-i', type=str, default=None, help='path to the input image')
    parser.add_argument('--output_dir', '-o', type=str, default=None, help='the directory to save the output')
    parser.add_argument('--pretrained_path', type=str, default=None, help='path to a model state dict to be used')
    parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
    parser.add_argument("--process_size", type=int, default=512)
    parser.add_argument("--upscale", type=int, default=4)
    parser.add_argument("--align_method", type=str, choices=['wavelet', 'adain', 'nofix'], default='adain')
    parser.add_argument("--pretrained_model_name_or_path", type=str, default='preset_models/stable-diffusion-2-1-base')
    parser.add_argument("--pretrained_model_path", type=str, default='preset_models/stable-diffusion-2-1-base')
    parser.add_argument('--prompt', type=str, default='', help='user prompts')
    parser.add_argument('--ram_path', type=str, default=None)
    parser.add_argument('--ram_ft_path', type=str, default=None)
    # Note: argparse's type=bool treats any non-empty string as True, so this flag is
    # effectively always True unless an empty string is passed.
    parser.add_argument('--save_prompts', type=bool, default=True)
    # tile setting
    parser.add_argument("--vae_decoder_tiled_size", type=int, default=224)
    parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024)
    parser.add_argument("--latent_tiled_size", type=int, default=96)
    parser.add_argument("--latent_tiled_overlap", type=int, default=32)
    # precision setting
    parser.add_argument("--mixed_precision", type=str, default="fp16")
    # merge lora
    parser.add_argument("--merge_and_unload_lora", default=False)
    # stages
    parser.add_argument("--stages", type=int, default=None)
    parser.add_argument("--load_cfr", action="store_true")

    args = parser.parse_args()

    # initialize the model
    model = Generator_eval(args)
    model.set_eval()

    if os.path.isdir(args.input_image):
        all_video_data = process_video_directory(args.input_image)
    else:
        # Handle single video case (if input is a single video file)
        all_video_data = [(os.path.basename(args.input_image).split('.')[0], extract_frames(args.input_image))]

    # get ram model
    DAPE = ram(pretrained=args.ram_path,
               pretrained_condition=args.ram_ft_path,
               image_size=384,
               vit='swin_l')
    DAPE.eval()
    DAPE.to("cuda")

    # weight type
    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # set weight type
    DAPE = DAPE.to(dtype=weight_dtype)
    model.vae = model.vae.to(dtype=weight_dtype)
    model.unet = model.unet.to(dtype=weight_dtype)
    model.cfr_main_net = model.cfr_main_net.to(dtype=weight_dtype)

    if args.stages == 0:
        model.unet.set_adapter(['default_encoder_consistency', 'default_decoder_consistency', 'default_others_consistency'])
    else:
        model.unet.set_adapter(['default_encoder_quality', 'default_decoder_quality',
                                'default_others_quality',
                                'default_encoder_consistency', 'default_decoder_consistency',
                                'default_others_consistency'])
    if args.save_prompts:
        txt_path = os.path.join(args.output_dir, 'txt')
        os.makedirs(txt_path, exist_ok=True)

    # make the output dir
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"There are {len(all_video_data)} videos to process.")
    frame_num = 2
    frame_overlap = 1
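    # Illustrative only: with frame_num=2 and frame_overlap=1, the loop further below
    # steps through overlapping windows (0,1), (1,2), (2,3), ...; a trailing window
    # that would hold a single frame ends processing via the shape[0] == 1 break.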

    for video_name, video_frame_images in all_video_data:
        print(f"Processing frames for video: {video_name}")

        # Define the save path for the processed video
        video_save_path = os.path.join(args.output_dir, video_name)
        if not os.path.exists(video_save_path):
            os.makedirs(video_save_path)

        # Initialize batches for storing input images and their grayscale versions
        input_image_batch = []
        input_image_gray_batch = []
        bname_batch = []
        prompt_batch = [] # Store prompts for each frame

        for image_name in video_frame_images:
            print(image_name)
            # make sure that the input image is a multiple of 8
            input_image = Image.open(image_name).convert('RGB')
            input_image_gray = input_image.convert('L')
            ori_width, ori_height = input_image.size
            rscale = args.upscale
            resize_flag = False

            # If the image is smaller than the required size, scale it up
            if ori_width < args.process_size // rscale or ori_height < args.process_size // rscale:
                scale = (args.process_size // rscale) / min(ori_width, ori_height)
                input_image = input_image.resize((int(scale * ori_width), int(scale * ori_height)))
                input_image_gray = input_image_gray.resize((int(scale * ori_width), int(scale * ori_height)))
                resize_flag = True

            # Upscale the image dimensions by the upscale factor
            input_image = input_image.resize((input_image.size[0] * rscale, input_image.size[1] * rscale))
            input_image_gray = input_image_gray.resize((input_image_gray.size[0] * rscale, input_image_gray.size[1] * rscale))

            # Adjust the image dimensions to make sure they are a multiple of 8
            new_width = input_image.width - input_image.width % 8
            new_height = input_image.height - input_image.height % 8
            input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
            input_image_gray = input_image_gray.resize((new_width, new_height), Image.LANCZOS)

            bname = os.path.basename(image_name)
            bname_batch.append(bname)

            # Always generate a new prompt for each frame
            validation_prompt = get_validation_prompt(args, input_image, DAPE)
            if args.save_prompts:
                txt_save_path = f"{txt_path}/{bname.split('.')[0]}.txt"
                with open(txt_save_path, 'w', encoding='utf-8') as f:
                    f.write(validation_prompt)

            print(f"process {image_name}, caption: {validation_prompt}".encode('utf-8'))
            input_image_batch.append(input_image)
            input_image_gray_batch.append(input_image_gray)
            prompt_batch.append(validation_prompt) # Add the generated prompt to the batch

        for input_image_index in range(0, len(input_image_batch), (frame_num - frame_overlap)):
            if input_image_index + frame_num - 1 >= len(input_image_batch):
                # Prevent out-of-bound issues for the last few frames
                end = len(input_image_batch) - input_image_index
                start = 0
            else:
                start = 0
                end = frame_num

            # Collect the batch of frames to be processed
            input_frames = []
            input_frames_gray = []
            for input_frame_index in range(start, end):
                real_idx = input_image_index + input_frame_index
                # Perform boundary checks to ensure indices are within range
                if real_idx < 0 or real_idx >= len(input_image_batch):
                    continue

                current_frame = transforms.functional.to_tensor(input_image_batch[real_idx])
                current_frame_gray = transforms.functional.to_tensor(input_image_gray_batch[real_idx])
                current_frame_gray = torch.nn.functional.interpolate(current_frame_gray.unsqueeze(0), scale_factor=0.125).squeeze(0)
                input_frames.append(current_frame)
                input_frames_gray.append(current_frame_gray)

            input_image_final = torch.stack(input_frames, dim=0)
            input_image_gray_final = torch.stack(input_frames_gray, dim=0)

            uncertainty_map = []
            if input_image_final.shape[0] == 1:
                break
            for image_index in range(input_image_final.shape[0]):
                if image_index != 0:
                    cur_img = input_image_gray_final[image_index]
                    prev_img = input_image_gray_final[image_index - 1]

                    # Mask of pixels that differ between the current and previous grayscale frames
                    compute_frame = torch.stack([cur_img, prev_img])
                    uncertainty_map_each = compute_frame_difference_mask(compute_frame)
                    uncertainty_map.append(uncertainty_map_each)

            uncertainty_map = torch.stack(uncertainty_map)

            # Get the prompt for the current frame
            current_prompt_index = input_image_index + start
            if current_prompt_index < len(prompt_batch):
                current_prompt = prompt_batch[current_prompt_index]
            else:
                # Fallback if index is out of bounds (should not happen with correct logic)
                current_prompt = ""

            # Model input [b=1, t, c, h, w]
            with torch.no_grad():
                # Normalize input image tensor to range [-1, 1]
                c_t = input_image_final.unsqueeze(0).cuda() * 2 - 1
                c_t = c_t.to(dtype=weight_dtype)
                output_image, _, _, _, _ = model(stages=args.stages, c_t=c_t, uncertainty_map=uncertainty_map.unsqueeze(0).cuda(), prompt=current_prompt, weight_dtype=weight_dtype)

            frame_t = output_image[0]  # shape: [c, h, w]
            frame_t = (frame_t.cpu() * 0.5 + 0.5)  # Convert the frame back to range [0, 1]
            output_pil = transforms.ToPILImage()(frame_t)

            # Find the index of the corresponding original image (start + output_index)
            src_idx = input_image_index + start + 1
            # Perform boundary check to ensure index is within valid range
            if src_idx < 0 or src_idx >= len(input_image_batch):
                src_idx = max(0, min(src_idx, len(input_image_batch) - 1))

            # Use the corresponding frame for color/band correction
            source_pil = input_image_batch[src_idx]

            if args.align_method == 'adain':
                output_pil = adain_color_fix(target=output_pil, source=source_pil)
            elif args.align_method == 'wavelet':
                output_pil = wavelet_color_fix(target=output_pil, source=source_pil)
            else:
                pass

            # If the image was resized earlier, resize it back to its original dimensions.
            # Note: resize_flag, ori_width and ori_height keep the values from the last
            # frame of the preprocessing loop above (all frames of a video are assumed
            # to share one resolution).
            if resize_flag:
                new_w = int(args.upscale * ori_width)
                new_h = int(args.upscale * ori_height)
                output_pil = output_pil.resize((new_w, new_h), Image.BICUBIC)

            global_frame_counter = src_idx
            out_name = f"frame_{global_frame_counter:04d}.png"
            out_path = f"{video_save_path}/{out_name}"

            output_pil.save(out_path)
            print(f"Saving frame {global_frame_counter} to {out_path}")

            gc.collect()
            torch.cuda.empty_cache()
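
# Example invocation (paths below are placeholders, not part of the original script):
#   python test_DLoRAL.py \
#       --input_image ./input_videos \
#       --output_dir ./results \
#       --pretrained_path ./checkpoints/dloral.pkl \
#       --ram_path ./preset_models/ram_swin_large_14m.pth \
#       --stages 1 --upscale 4 --mixed_precision fp16
# Processed frames are written to <output_dir>/<video_name>/frame_XXXX.png and, with
# --save_prompts left at its default, per-frame captions to <output_dir>/txt/.
# To reassemble the processed frames into a video (outside this script), a command
# such as "ffmpeg -framerate 25 -i frame_%04d.png -c:v libx264 -pix_fmt yuv420p out.mp4"
# could be used.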