Advertisement
zelenooki87

Real DRCT GAN: FP16 ONNX inference

Jun 24th, 2024 (edited)
262
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.46 KB | Science | 0 0
  1. import argparse
  2. import cv2
  3. import glob
  4. import numpy as np
  5. import os
  6. import onnxruntime
  7. import time
  8. import math
  9. from tqdm import tqdm
  10.  
  11. def main():
  12.     parser = argparse.ArgumentParser()
  13.     parser.add_argument('--model_path', type=str, default='model.onnx', help='Path to the ONNX model')
  14.     parser.add_argument('--input', type=str, default='input', help='Input folder with images')
  15.     parser.add_argument('--output', type=str, default='output', help='Output folder')
  16.     parser.add_argument('--scale', type=int, default=4, help='Upscaling factor')
  17.     parser.add_argument('--tile_size', type=int, default=512, help='Tile size for processing')
  18.     parser.add_argument('--tile_pad', type=int, default=32, help='Padding around tiles')
  19.     args = parser.parse_args()
  20.  
  21.     # Load the ONNX model with CUDA Execution Provider
  22.     ort_session = onnxruntime.InferenceSession(args.model_path, providers=['CUDAExecutionProvider'])
  23.     input_name = ort_session.get_inputs()[0].name
  24.  
  25.     # Create output folder if it doesn't exist
  26.     os.makedirs(args.output, exist_ok=True)
  27.  
  28.     # Process each image in the input folder
  29.     for image_path in tqdm(glob.glob(os.path.join(args.input, '*')), desc="Processing images", unit="image"):
  30.         # Load image and normalize
  31.         img = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0
  32.         original_height, original_width = img.shape[:2]
  33.  
  34.         # Upscale image using tiling
  35.         output_img = tile_process(img, ort_session, input_name, args.scale, args.tile_size, args.tile_pad)
  36.  
  37.         # Convert to uint8 and save the upscaled image
  38.         output_img = (output_img * 255.0).round().astype(np.uint8)
  39.         output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
  40.  
  41.         # Construct output filename with suffix and .png extension
  42.         filename, _ = os.path.splitext(os.path.basename(image_path))
  43.         output_filename = f"{filename}_REAL_GAN_DRCT.png"
  44.         cv2.imwrite(os.path.join(args.output, output_filename), output_img)
  45.  
  46.  
  47. def tile_process(img, ort_session, input_name, scale, tile_size, tile_pad):
  48.     """Processes the image in tiles to avoid OOM errors."""
  49.     height, width = img.shape[:2]
  50.     output_height = height * scale
  51.     output_width = width * scale
  52.     output_shape = (output_height, output_width, 3)
  53.  
  54.     # Start with black image
  55.     output_img = np.zeros(output_shape, dtype=np.float32)
  56.  
  57.     # Calculate number of tiles
  58.     tiles_x = math.ceil(width / tile_size)
  59.     tiles_y = math.ceil(height / tile_size)
  60.  
  61.     # Loop over all tiles
  62.     for y in range(tiles_y):
  63.         for x in range(tiles_x):
  64.             # Extract tile from input image
  65.             ofs_x = x * tile_size
  66.             ofs_y = y * tile_size
  67.             input_start_x = ofs_x
  68.             input_end_x = min(ofs_x + tile_size, width)
  69.             input_start_y = ofs_y
  70.             input_end_y = min(ofs_y + tile_size, height)
  71.  
  72.             # Input tile area on total image with padding
  73.             input_start_x_pad = max(input_start_x - tile_pad, 0)
  74.             input_end_x_pad = min(input_end_x + tile_pad, width)
  75.             input_start_y_pad = max(input_start_y - tile_pad, 0)
  76.             input_end_y_pad = min(input_end_y + tile_pad, height)
  77.  
  78.             # Input tile dimensions
  79.             input_tile_width = input_end_x - input_start_x
  80.             input_tile_height = input_end_y - input_start_y
  81.             tile_idx = y * tiles_x + x + 1
  82.             input_tile = img[input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad, :]
  83.  
  84.             # Pad tile to be divisible by scaling factor
  85.             input_tile = pad_image(input_tile, 16)
  86.  
  87.             # Convert to BGR, transpose to CHW, and add batch dimension
  88.             input_tile = np.transpose(input_tile[:, :, [2, 1, 0]], (2, 0, 1))
  89.             input_tile = np.expand_dims(input_tile, axis=0).astype(np.float16)
  90.  
  91.             # Run inference
  92.             output_tile = ort_session.run(None, {input_name: input_tile})[0]
  93.  
  94.             # Post-process the output tile
  95.             output_tile = np.clip(output_tile, 0, 1)
  96.             output_tile = np.transpose(output_tile[0, :, :, :], (1, 2, 0))
  97.  
  98.             # Output tile area on total image
  99.             output_start_x = input_start_x * scale
  100.             output_end_x = input_end_x * scale
  101.             output_start_y = input_start_y * scale
  102.             output_end_y = input_end_y * scale
  103.  
  104.             # Output tile area without padding
  105.             output_start_x_tile = (input_start_x - input_start_x_pad) * scale
  106.             output_end_x_tile = output_start_x_tile + input_tile_width * scale
  107.             output_start_y_tile = (input_start_y - input_start_y_pad) * scale
  108.             output_end_y_tile = output_start_y_tile + input_tile_height * scale
  109.  
  110.             # Put tile into output image
  111.             output_img[output_start_y:output_end_y, output_start_x:output_end_x, :] = output_tile[
  112.                 output_start_y_tile:output_end_y_tile, output_start_x_tile:output_end_x_tile, :
  113.             ]
  114.  
  115.             print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
  116.  
  117.     return output_img
  118.  
  119.  
  120. def pad_image(img, factor):
  121.     """Pads the image to be divisible by the given factor using reflection padding."""
  122.     height, width = img.shape[:2]
  123.     pad_height = (factor - (height % factor)) % factor
  124.     pad_width = (factor - (width % factor)) % factor
  125.     return cv2.copyMakeBorder(img, 0, pad_height, 0, pad_width, cv2.BORDER_REFLECT_101)
  126.  
  127.  
  128. if __name__ == '__main__':
  129.     main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement