Advertisement
zelenooki87

Real DRCT fp16 onnx inference script simplified

Oct 11th, 2024 (edited)
177
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.30 KB | Science | 0 0
  1. import argparse
  2. import os
  3. import cv2
  4. import numpy as np
  5. import onnxruntime
  6. from tqdm import tqdm
  7. from glob import glob
  8.  
  9. def main():
  10.     parser = argparse.ArgumentParser(description="Upscale images using an ONNX model.")
  11.     parser.add_argument('--model_path', type=str, default='Real-DRCT-GAN_Finetuned.onnx', help='Path to the ONNX model')
  12.     parser.add_argument('--input', type=str, default='input', help='Input folder with images')
  13.     parser.add_argument('--output', type=str, default='output', help='Output folder')
  14.     parser.add_argument('--scale', type=int, default=4, help='Upscaling factor')
  15.     parser.add_argument('--tile_size', type=int, default=512, help='Tile size for processing')
  16.     parser.add_argument('--tile_pad', type=int, default=32, help='Padding around tiles')
  17.     args = parser.parse_args()
  18.  
  19.     os.makedirs(args.output, exist_ok=True)
  20.     ort_session = onnxruntime.InferenceSession(args.model_path, providers=['CUDAExecutionProvider'])
  21.     input_name = ort_session.get_inputs()[0].name
  22.  
  23.     for image_path in tqdm(glob(os.path.join(args.input, '*')), desc="Processing images"):
  24.         process_image(image_path, ort_session, input_name, args)
  25.  
  26. def process_image(image_path, ort_session, input_name, args):
  27.     img = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0
  28.     output_img = tile_process(img, ort_session, input_name, args)
  29.     output_path = os.path.join(args.output, f"{os.path.splitext(os.path.basename(image_path))[0]}_upscaled.bmp")
  30.     cv2.imwrite(output_path, cv2.cvtColor((output_img * 255.0).astype(np.uint8), cv2.COLOR_BGR2RGB))
  31.  
  32. def tile_process(img, ort_session, input_name, args):
  33.     scale, tile_size, tile_pad = args.scale, args.tile_size, args.tile_pad
  34.     height, width = img.shape[:2]
  35.     output_img = np.zeros((height * scale, width * scale, 3), dtype=np.float32)
  36.     img = np.transpose(img[..., ::-1], (2, 0, 1)).astype(np.float16)
  37.  
  38.     for y in range(0, height, tile_size):
  39.         for x in range(0, width, tile_size):
  40.             x_start, x_end = x, min(x + tile_size, width)
  41.             y_start, y_end = y, min(y + tile_size, height)
  42.  
  43.             x_pad_start, x_pad_end = max(x_start - tile_pad, 0), min(x_end + tile_pad, width)
  44.             y_pad_start, y_pad_end = max(y_start - tile_pad, 0), min(y_end + tile_pad, height)
  45.  
  46.             input_tile = img[:, y_pad_start:y_pad_end, x_pad_start:x_pad_end]
  47.             input_tile = np.pad(input_tile, ((0, 0), (0, (-input_tile.shape[1]) % 16), (0, (-input_tile.shape[2]) % 16)), 'reflect')[None, ...]
  48.            
  49.             output_tile = ort_session.run(None, {input_name: input_tile})[0][0]
  50.             output_tile = np.clip(output_tile, 0, 1).transpose(1, 2, 0)
  51.  
  52.             out_x_start, out_x_end = x_start * scale, x_end * scale
  53.             out_y_start, out_y_end = y_start * scale, y_end * scale
  54.  
  55.             out_tile_x_start = (x_start - x_pad_start) * scale
  56.             out_tile_x_end = out_tile_x_start + (x_end - x_start) * scale
  57.             out_tile_y_start = (y_start - y_pad_start) * scale
  58.             out_tile_y_end = out_tile_y_start + (y_end - y_start) * scale
  59.  
  60.             output_img[out_y_start:out_y_end, out_x_start:out_x_end] = output_tile[out_tile_y_start:out_tile_y_end, out_tile_x_start:out_tile_x_end]
  61.  
  62.     return output_img
  63.  
# Run the upscaler only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement