Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import argparse
- import os
- import cv2
- import numpy as np
- import onnxruntime
- from tqdm import tqdm
- from glob import glob
def main():
    """Upscale every image file in a folder with an ONNX model.

    Parses the command-line options, creates the output folder, loads the
    model into an ONNX Runtime session and calls ``process_image`` once for
    each regular file found directly inside the input folder.
    """
    parser = argparse.ArgumentParser(description="Upscale images using an ONNX model.")
    parser.add_argument('--model_path', type=str, default='Real-DRCT-GAN_Finetuned.onnx', help='Path to the ONNX model')
    parser.add_argument('--input', type=str, default='input', help='Input folder with images')
    parser.add_argument('--output', type=str, default='output', help='Output folder')
    parser.add_argument('--scale', type=int, default=4, help='Upscaling factor')
    parser.add_argument('--tile_size', type=int, default=512, help='Tile size for processing')
    parser.add_argument('--tile_pad', type=int, default=32, help='Padding around tiles')
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)

    # Providers are tried in order: CUDA is preferred, and the CPU provider is
    # listed explicitly so the script still runs on machines without a GPU.
    ort_session = onnxruntime.InferenceSession(
        args.model_path,
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    )
    input_name = ort_session.get_inputs()[0].name

    # Skip sub-directories (they cannot be decoded as images) and sort so the
    # processing order is deterministic across runs and platforms.
    image_paths = sorted(
        path for path in glob(os.path.join(args.input, '*')) if os.path.isfile(path)
    )
    for image_path in tqdm(image_paths, desc="Processing images"):
        process_image(image_path, ort_session, input_name, args)
def process_image(image_path, ort_session, input_name, args):
    """Upscale a single image file and save it as a BMP in ``args.output``.

    Args:
        image_path: Path of the file to read with OpenCV.
        ort_session: Inference session forwarded to ``tile_process``.
        input_name: Name of the model input forwarded to ``tile_process``.
        args: Parsed options; ``args.output`` is the destination folder, the
            remaining fields are consumed by ``tile_process``.

    Files that OpenCV cannot decode are skipped with a message instead of
    aborting the whole batch. The result is written as
    ``<original stem>_upscaled.bmp``.
    """
    bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if bgr is None:
        # cv2.imread reports unreadable / non-image files by returning None.
        tqdm.write(f"Skipping '{image_path}': not a readable image")
        return

    img = bgr.astype(np.float32) / 255.0
    output_img = tile_process(img, ort_session, input_name, args)

    stem = os.path.splitext(os.path.basename(image_path))[0]
    output_path = os.path.join(args.output, f"{stem}_upscaled.bmp")

    # Round to nearest instead of truncating so e.g. 254.7 becomes 255, and
    # clip so out-of-range values cannot wrap around in the uint8 cast.
    pixels = np.clip(np.rint(output_img * 255.0), 0, 255).astype(np.uint8)

    # tile_process reverses the channel order (BGR -> RGB) before inference,
    # so swap the channels back to the BGR layout cv2.imwrite expects.
    if not cv2.imwrite(output_path, cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)):
        tqdm.write(f"Failed to write '{output_path}'")
def _model_input_dtype(ort_session, input_name):
    """Return the NumPy dtype matching the model's declared input tensor type.

    Falls back to float16 (the dtype this script historically used) when the
    input is not found or declares a type other than float16/float/double.
    """
    known = {
        'tensor(float16)': np.float16,
        'tensor(float)': np.float32,
        'tensor(double)': np.float64,
    }
    for node in ort_session.get_inputs():
        if node.name == input_name:
            return known.get(node.type, np.float16)
    return np.float16


def tile_process(img, ort_session, input_name, args):
    """Upscale an image tile by tile and stitch the results together.

    The image is cut into ``args.tile_size`` squares. Each square is extended
    by up to ``args.tile_pad`` pixels of surrounding context (clamped at the
    image border), reflect-padded on the bottom/right to a multiple of 16,
    run through the model, and only the part of the model output that
    corresponds to the un-padded square is copied into the result.

    Args:
        img: ``(H, W, 3)`` array; ``process_image`` passes an OpenCV BGR image
            scaled to ``[0, 1]``.
        ort_session: Object exposing ``get_inputs()`` and
            ``run(None, {input_name: tensor})`` like an ONNX Runtime session.
        input_name: Name of the model input to feed.
        args: Namespace with integer ``scale``, ``tile_size`` and ``tile_pad``.

    Returns:
        ``float32`` array of shape ``(H * scale, W * scale, 3)`` clipped to
        ``[0, 1]``. The channel axis of ``img`` is reversed before inference
        and not reversed back, so a BGR input yields an RGB result (assuming
        the model keeps the channel order of its input).

    Raises:
        ValueError: If ``scale``/``tile_size`` are not positive, ``tile_pad``
            is negative, or the model output size does not match ``scale``.
    """
    scale, tile_size, tile_pad = args.scale, args.tile_size, args.tile_pad
    if scale < 1 or tile_size < 1 or tile_pad < 0:
        raise ValueError(
            "scale and tile_size must be >= 1 and tile_pad must be >= 0 "
            f"(got scale={scale}, tile_size={tile_size}, tile_pad={tile_pad})"
        )

    height, width = img.shape[:2]
    output_img = np.zeros((height * scale, width * scale, 3), dtype=np.float32)

    # HWC -> CHW with the channel order reversed, cast to the dtype the model
    # declares for its input rather than assuming a half-precision model.
    model_dtype = _model_input_dtype(ort_session, input_name)
    img = np.transpose(img[..., ::-1], (2, 0, 1)).astype(model_dtype)

    for y_start in range(0, height, tile_size):
        y_end = min(y_start + tile_size, height)
        y_pad_start = max(y_start - tile_pad, 0)
        y_pad_end = min(y_end + tile_pad, height)

        for x_start in range(0, width, tile_size):
            x_end = min(x_start + tile_size, width)
            x_pad_start = max(x_start - tile_pad, 0)
            x_pad_end = min(x_end + tile_pad, width)

            input_tile = img[:, y_pad_start:y_pad_end, x_pad_start:x_pad_end]

            # Grow the tile on the bottom/right to the next multiple of 16;
            # the extra rows/columns are cropped away again below.
            pad_h = -input_tile.shape[1] % 16
            pad_w = -input_tile.shape[2] % 16
            input_tile = np.pad(
                input_tile, ((0, 0), (0, pad_h), (0, pad_w)), 'reflect'
            )[None, ...]

            output_tile = ort_session.run(None, {input_name: input_tile})[0][0]

            expected_hw = (input_tile.shape[2] * scale, input_tile.shape[3] * scale)
            if tuple(output_tile.shape[-2:]) != expected_hw:
                raise ValueError(
                    f"model produced a tile of size {tuple(output_tile.shape[-2:])}, "
                    f"expected {expected_hw}; does the scale option ({scale}) "
                    "match the model?"
                )

            output_tile = np.clip(output_tile, 0, 1).transpose(1, 2, 0)

            # Location of the un-padded tile inside the model output.
            out_tile_y_start = (y_start - y_pad_start) * scale
            out_tile_y_end = out_tile_y_start + (y_end - y_start) * scale
            out_tile_x_start = (x_start - x_pad_start) * scale
            out_tile_x_end = out_tile_x_start + (x_end - x_start) * scale

            output_img[
                y_start * scale:y_end * scale,
                x_start * scale:x_end * scale,
            ] = output_tile[
                out_tile_y_start:out_tile_y_end,
                out_tile_x_start:out_tile_x_end,
            ]

    return output_img
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement