Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import argparse
import glob
import os
import shutil
import subprocess
import sys
import time

import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image
from torchvision import transforms

# Make the project-local packages imported below (osediff, my_utils, ram)
# importable when the script is launched from the directory containing them.
sys.path.append(os.getcwd())

from osediff import OSEDiff_test
from my_utils.wavelet_color_fix import adain_color_fix, wavelet_color_fix
from ram.models.ram_lora import ram
from ram import inference_ram as inference
# PIL image -> float CHW tensor (values in [0, 1]).
tensor_transforms = transforms.Compose([transforms.ToTensor()])

# Preprocessing applied to the low-quality frame before RAM/DAPE tagging:
# resize to the 384x384 input size the tagger is built with, then normalize
# with the commonly used ImageNet per-channel mean/std.
ram_transforms = transforms.Compose(
    [
        transforms.Resize((384, 384)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def get_validation_prompt(args, image, model, device='cuda', weight_dtype=None):
    """Build the text prompt for one frame from RAM/DAPE tags plus the user prompt.

    Args:
        args: Parsed CLI namespace; only ``args.prompt`` is read.
        image (PIL.Image.Image): The low-quality frame to tag.
        model: The RAM/DAPE tagging model handed to ``inference``.
        device (str): Device the preprocessed tensor is moved to.
        weight_dtype (torch.dtype, optional): dtype the tagger input is cast
            to. When omitted, the dtype of the tagger's first parameter is
            used, so the input matches whatever precision the model was cast
            to (falls back to float32 for a parameterless model).

    Returns:
        str: ``"<first caption>, <args.prompt>,"``.
    """
    if weight_dtype is None:
        # Derive the dtype from the model itself instead of relying on a
        # module-level global that only exists when this file runs as a script.
        try:
            weight_dtype = next(model.parameters()).dtype
        except (AttributeError, StopIteration):
            weight_dtype = torch.float32
    lq = tensor_transforms(image).unsqueeze(0).to(device)
    lq = ram_transforms(lq).to(dtype=weight_dtype)
    captions = inference(lq, model)
    return f"{captions[0]}, {args.prompt},"
def _probe_stream_entry(video_path, entry):
    """Return ffprobe's raw ``stream=<entry>`` value for the first video stream."""
    cmd = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries', f'stream={entry}',
           '-of', 'default=noprint_wrappers=1:nokey=1', video_path]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f'ffprobe error: {result.stderr}')
    return result.stdout


def _parse_frame_rate(rate):
    """Convert an ffprobe rate string ('30000/1001' or '25') to fps; None if unusable."""
    num, sep, den = rate.strip().partition('/')
    try:
        numerator = int(num)
        denominator = int(den) if sep else 1
    except ValueError:
        return None
    # ffprobe reports "0/0" when it cannot determine a rate.
    if numerator <= 0 or denominator <= 0:
        return None
    return numerator / denominator


def get_average_fps(video_path):
    """
    Gets the average FPS of a video using ffprobe.

    Falls back to the stream's ``r_frame_rate`` when ``avg_frame_rate`` is
    missing or reported as "0/0", instead of dividing by zero.

    Args:
        video_path (str): Path to the video file.
    Returns:
        float: The average FPS of the video.
    Raises:
        RuntimeError: If ffprobe fails or neither rate can be determined.
    """
    fps = _parse_frame_rate(_probe_stream_entry(video_path, 'avg_frame_rate'))
    if fps is None:
        fps = _parse_frame_rate(_probe_stream_entry(video_path, 'r_frame_rate'))
    if fps is None:
        raise RuntimeError(f'Could not determine the frame rate of {video_path}')
    return fps
def _extract_frames(input_video_path, temp_dir, avg_framerate):
    """Decode the video into ``<temp_dir>/frameNNNN.bmp`` at the given frame rate."""
    subprocess.run([
        'ffmpeg',
        '-y',  # never block on an interactive overwrite prompt
        '-i', input_video_path,
        '-vf', f'fps={avg_framerate}',
        f'{temp_dir}/frame%04d.bmp',
    ], check=True)


def _upscale_frame(args, model, DAPE, image_name, temp_dir):
    """Super-resolve one extracted frame and write the result back into *temp_dir*."""
    # Context manager closes the file handle; convert() returns a loaded copy.
    with Image.open(image_name) as frame:
        input_image = frame.convert('RGB')
    ori_width, ori_height = input_image.size
    rscale = args.upscale
    min_side = args.process_size // rscale
    # Frames smaller than process_size // upscale are first enlarged so their
    # short side reaches that size; the result is shrunk back at the end.
    resize_flag = ori_width < min_side or ori_height < min_side
    if resize_flag:
        scale = min_side / min(ori_width, ori_height)
        input_image = input_image.resize((int(scale * ori_width), int(scale * ori_height)))
    input_image = input_image.resize((input_image.size[0] * rscale, input_image.size[1] * rscale))
    # Trim to multiples of 8 (presumably required by the latent VAE's 8x
    # downsampling -- confirm against OSEDiff_test).
    new_width = input_image.width - input_image.width % 8
    new_height = input_image.height - input_image.height % 8
    input_image = input_image.resize((new_width, new_height), Image.LANCZOS)

    # get caption
    validation_prompt = get_validation_prompt(args, input_image, DAPE)
    print(f"process {image_name}, tag: {validation_prompt}".encode('utf-8'))

    # translate the image
    with torch.no_grad():
        # to_tensor yields [0, 1]; the model is fed [-1, 1] and its output is
        # mapped back with * 0.5 + 0.5.
        lq = F.to_tensor(input_image).unsqueeze(0).cuda() * 2 - 1
        output_image = model(lq, prompt=validation_prompt)
        output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
        if args.align_method == 'adain':
            output_pil = adain_color_fix(target=output_pil, source=input_image)
        elif args.align_method == 'wavelet':
            output_pil = wavelet_color_fix(target=output_pil, source=input_image)

    if resize_flag:
        # Image.resize returns a new image, so the result must be reassigned.
        output_pil = output_pil.resize((int(args.upscale * ori_width), int(args.upscale * ori_height)))
    output_pil.save(os.path.join(temp_dir, os.path.basename(image_name)))


def _encode_video(frames_pattern, input_video_path, avg_framerate, output_video_path):
    """Mux the upscaled frames with the source's audio (if any) into an FFV1 file."""
    subprocess.run([
        'ffmpeg',
        '-y',  # Overwrite output files without asking
        '-framerate', str(avg_framerate),
        '-i', frames_pattern,
        '-i', input_video_path,
        '-map', '0:v',
        # The trailing '?' makes the audio mapping optional, so inputs without
        # an audio track still encode instead of failing with
        # "Stream map '1:a' matches no streams".
        '-map', '1:a?',
        '-c:v', 'ffv1',
        '-pix_fmt', 'rgb48',
        '-metadata:s:v:0', 'encoder=FFV1',
        '-level', '3',
        '-g', '1',
        '-slices', '24',
        '-slicecrc', '1',
        output_video_path,
    ], check=True)


def process_video(args, model, DAPE, weight_dtype, input_video_path, processing_times):
    """Upscale every frame of one video with OSEDiff and re-encode it.

    Frames are decoded to BMP in ``<output_dir>/temp_frames/<video stem>/``.
    Each frame is tagged with RAM/DAPE, super-resolved by *model*, and
    written over its extracted copy. The frames are then muxed with the
    source's audio (when present) into
    ``<output_dir>/<video stem>_upscaled.mov`` using FFV1, and the temporary
    frame directory is deleted.

    Args:
        args: Parsed CLI namespace (output_dir, upscale, process_size,
            align_method, prompt).
        model: OSEDiff_test instance, called as ``model(lq, prompt=...)``.
        DAPE: RAM/DAPE tagging model used to build per-frame prompts.
        weight_dtype: Not used by this function; kept for signature compatibility.
        input_video_path (str): Video to process.
        processing_times (list): On success, ``(input_video_path, seconds)``
            is appended; the time covers frame processing and encoding, not
            frame extraction.

    Raises:
        RuntimeError: Propagated from get_average_fps when ffprobe fails.
        subprocess.CalledProcessError: If an ffmpeg invocation fails (the
            temporary frames are left in place for inspection).
    """
    video_stem = os.path.splitext(os.path.basename(input_video_path))[0]
    temp_dir = os.path.join(args.output_dir, "temp_frames", video_stem)
    # Start from an empty directory: leftovers from an interrupted run would
    # otherwise be picked up by the glob below as extra frames.
    shutil.rmtree(temp_dir, ignore_errors=True)
    os.makedirs(temp_dir, exist_ok=True)

    avg_framerate = get_average_fps(input_video_path)
    _extract_frames(input_video_path, temp_dir, avg_framerate)

    start_time = time.time()
    for image_name in sorted(glob.glob(f'{temp_dir}/*.bmp')):
        _upscale_frame(args, model, DAPE, image_name, temp_dir)

    output_video_path = os.path.join(args.output_dir, video_stem + "_upscaled.mov")
    _encode_video(f'{temp_dir}/frame%04d.bmp', input_video_path, avg_framerate, output_video_path)

    # Clean up temporary frames (portable replacement for `rm -r`).
    shutil.rmtree(temp_dir)
    processing_times.append((input_video_path, time.time() - start_time))
def str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "1"/"0", "yes"/"no".

    argparse's ``type=bool`` turns every non-empty string -- including
    "False" -- into True, so boolean options are parsed explicitly.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def is_video_file(file_path):
    """Return True if ffprobe reports a video stream as the first video stream of *file_path*."""
    probe_cmd = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries', 'stream=codec_type',
                 '-of', 'default=noprint_wrappers=1:nokey=1', file_path]
    result = subprocess.run(probe_cmd, capture_output=True, text=True)
    return result.returncode == 0 and result.stdout.strip() == 'video'


def find_video_files(input_path):
    """List the videos to process, in sorted order.

    If *input_path* is a file it is the only candidate; otherwise every
    regular file directly inside the directory is a candidate. Candidates
    that ffprobe does not recognise as video are skipped; probe errors are
    printed and the file is skipped.
    """
    if os.path.isfile(input_path):
        candidates = [input_path]
    else:
        candidates = [os.path.join(input_path, name) for name in sorted(os.listdir(input_path))]
    video_paths = []
    for file_path in candidates:
        if not os.path.isfile(file_path):
            continue
        try:
            if is_video_file(file_path):
                video_paths.append(file_path)
        except (OSError, subprocess.SubprocessError, UnicodeDecodeError) as e:
            print(f"Error checking file {file_path}: {e}")
    return video_paths


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_image', '-i', type=str, default='input', help='path to the input image or video')
    parser.add_argument('--output_dir', '-o', type=str, default='output', help='the directory to save the output')
    parser.add_argument('--pretrained_model_name_or_path', type=str, default='preset/models/stable-diffusion-2-1-base', help='sd model path')
    parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
    parser.add_argument("--process_size", type=int, default=512)
    parser.add_argument("--upscale", type=int, default=4)
    parser.add_argument("--align_method", type=str, choices=['wavelet', 'adain', 'nofix'], default='adain')
    parser.add_argument("--osediff_path", type=str, default='preset/models/osediff.pkl')
    parser.add_argument('--prompt', type=str, default='', help='user prompts')
    parser.add_argument('--ram_path', type=str, default='preset/models/ram_swin_large_14m.pth')
    parser.add_argument('--ram_ft_path', type=str, default='preset/models/DAPE.pth')
    parser.add_argument('--save_prompts', type=str2bool, default=True)
    # precision setting
    parser.add_argument("--mixed_precision", type=str, choices=['fp16', 'fp32'], default="fp16")
    # merge lora
    parser.add_argument("--merge_and_unload_lora", type=str2bool, default=False)  # merge lora weights before inference
    # tile setting
    parser.add_argument("--vae_decoder_tiled_size", type=int, default=224)
    parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024)
    parser.add_argument("--latent_tiled_size", type=int, default=96)
    parser.add_argument("--latent_tiled_overlap", type=int, default=32)
    args = parser.parse_args()

    # initialize the model
    model = OSEDiff_test(args)

    # get ram model
    DAPE = ram(pretrained=args.ram_path,
               pretrained_condition=args.ram_ft_path,
               image_size=384,
               vit='swin_l')
    DAPE.eval()
    DAPE.to("cuda")

    # weight type (kept module-level: other functions in this file may read it)
    weight_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.float32
    DAPE = DAPE.to(dtype=weight_dtype)

    os.makedirs(args.output_dir, exist_ok=True)

    input_video_paths = find_video_files(args.input_image)
    if not input_video_paths:
        print(f"No video files found in {args.input_image}")

    # Process each video file and store processing times
    processing_times = []
    for input_video_path in input_video_paths:
        print(f"Processing video: {input_video_path}")
        try:
            process_video(args, model, DAPE, weight_dtype, input_video_path, processing_times)
        except (RuntimeError, subprocess.CalledProcessError) as e:
            # Report and move on so one bad file does not abort the whole batch.
            print(f"Failed to process {input_video_path}: {e}", file=sys.stderr)

    # Print processing times after all videos are processed
    print("\nProcessing times:")
    for input_video_path, processing_time in processing_times:
        print(f"{input_video_path}: {processing_time:.2f} seconds")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement