zelenooki87

osediff.py, NO TILING

Sep 4th, 2024
import os
import sys
sys.path.append(os.getcwd())
import yaml
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import DDPMScheduler
from models.autoencoder_kl import AutoencoderKL
from models.unet_2d_condition import UNet2DConditionModel
from peft import LoraConfig

from my_utils.vaehook import VAEHook, perfcount

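# NOTE ("NO TILING"): this variant of OSEDiff's inference wrapper has the
# tiled-VAE hook and the tiled-UNet loop removed, so the full latent is
# encoded, denoised in a single step, and decoded in one pass; peak VRAM
# therefore grows with input resolution.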
class OSEDiff_test(torch.nn.Module):
    def __init__(self, args):
        super().__init__()

        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.pretrained_model_name_or_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(self.args.pretrained_model_name_or_path, subfolder="text_encoder")
        self.noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
        self.noise_scheduler.set_timesteps(1, device=self.device)
        self.vae = AutoencoderKL.from_pretrained(self.args.pretrained_model_name_or_path, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(self.args.pretrained_model_name_or_path, subfolder="unet")

        # REMOVED LINE: self._init_tiled_vae(encoder_tile_size=args.vae_encoder_tiled_size, decoder_tile_size=args.vae_decoder_tiled_size)

        self.weight_dtype = torch.float32
        if args.mixed_precision == "fp16":
            self.weight_dtype = torch.float16

        # checkpoint dict holding the LoRA ranks, target-module lists and state dicts
        osediff = torch.load(args.osediff_path)
        self.load_ckpt(osediff)

        # merge lora
        if self.args.merge_and_unload_lora:
            print('===> MERGE LORA <===')
            self.vae = self.vae.merge_and_unload()
            self.unet = self.unet.merge_and_unload()

        self.unet.to(self.device, dtype=self.weight_dtype)
        self.vae.to(self.device, dtype=self.weight_dtype)
        self.text_encoder.to(self.device, dtype=self.weight_dtype)
        self.timesteps = torch.tensor([999], device=self.device).long()
        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device)

    def load_ckpt(self, model):
        # load unet lora
        lora_conf_encoder = LoraConfig(r=model["rank_unet"], init_lora_weights="gaussian", target_modules=model["unet_lora_encoder_modules"])
        lora_conf_decoder = LoraConfig(r=model["rank_unet"], init_lora_weights="gaussian", target_modules=model["unet_lora_decoder_modules"])
        lora_conf_others = LoraConfig(r=model["rank_unet"], init_lora_weights="gaussian", target_modules=model["unet_lora_others_modules"])
        self.unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
        self.unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
        self.unet.add_adapter(lora_conf_others, adapter_name="default_others")
        for n, p in self.unet.named_parameters():
            if "lora" in n or "conv_in" in n:
                p.data.copy_(model["state_dict_unet"][n])
        self.unet.set_adapter(["default_encoder", "default_decoder", "default_others"])

        # load vae lora
        vae_lora_conf_encoder = LoraConfig(r=model["rank_vae"], init_lora_weights="gaussian", target_modules=model["vae_lora_encoder_modules"])
        self.vae.add_adapter(vae_lora_conf_encoder, adapter_name="default_encoder")
        for n, p in self.vae.named_parameters():
            if "lora" in n:
                p.data.copy_(model["state_dict_vae"][n])
        self.vae.set_adapter(['default_encoder'])

    def encode_prompt(self, prompt_batch):
        prompt_embeds_list = []
        with torch.no_grad():
            for caption in prompt_batch:
                text_input_ids = self.tokenizer(
                    caption, max_length=self.tokenizer.model_max_length,
                    padding="max_length", truncation=True, return_tensors="pt"
                ).input_ids
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(self.text_encoder.device),
                )[0]
                prompt_embeds_list.append(prompt_embeds)
        prompt_embeds = torch.concat(prompt_embeds_list, dim=0)
        return prompt_embeds

    @perfcount
    @torch.no_grad()
    def forward(self, lq, prompt):

        prompt_embeds = self.encode_prompt([prompt])
        lq_latent = self.vae.encode(lq.to(self.weight_dtype)).latent_dist.sample() * self.vae.config.scaling_factor

        # TILING LOGIC REMOVED
        model_pred = self.unet(lq_latent, self.timesteps, encoder_hidden_states=prompt_embeds).sample  # pass the full latent through the UNet in one shot

        x_denoised = self.noise_scheduler.step(model_pred, self.timesteps, lq_latent, return_dict=True).prev_sample
        output_image = (self.vae.decode(x_denoised.to(self.weight_dtype) / self.vae.config.scaling_factor).sample).clamp(-1, 1)

        return output_image

    # NOTE: no longer called anywhere in this file since the tiled-VAE init
    # line in __init__ was removed; kept for reference.
    def _init_tiled_vae(self,
            encoder_tile_size=256,
            decoder_tile_size=256,
            fast_decoder=False,
            fast_encoder=False,
            color_fix=False,
            vae_to_gpu=True):
        # save original forward (only once)
        if not hasattr(self.vae.encoder, 'original_forward'):
            setattr(self.vae.encoder, 'original_forward', self.vae.encoder.forward)
        if not hasattr(self.vae.decoder, 'original_forward'):
            setattr(self.vae.decoder, 'original_forward', self.vae.decoder.forward)

        encoder = self.vae.encoder
        decoder = self.vae.decoder

        self.vae.encoder.forward = VAEHook(
            encoder, encoder_tile_size, is_decoder=False, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)
        self.vae.decoder.forward = VAEHook(
            decoder, decoder_tile_size, is_decoder=True, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)

    def _gaussian_weights(self, tile_width, tile_height, nbatches):
        """Generates a gaussian mask of weights for tile contributions"""
        from numpy import pi, exp, sqrt
        import numpy as np

        latent_width = tile_width
        latent_height = tile_height

        var = 0.01
        midpoint = (latent_width - 1) / 2  # -1 because index goes from 0 to latent_width - 1
        x_probs = [exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var) for x in range(latent_width)]
        midpoint = (latent_height - 1) / 2  # same centering as above (original used latent_height / 2)
        y_probs = [exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var) for y in range(latent_height)]

        weights = np.outer(y_probs, x_probs)
        return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1))
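
# --- Usage sketch (illustrative, not from the original paste) ---
# A minimal driver, assuming an `args` namespace with exactly the attributes
# read in __init__ above. The base-model default ("stabilityai/sd-turbo"),
# the input/output file names and the prompt are placeholder assumptions.
if __name__ == "__main__":
    import argparse
    from PIL import Image
    import torchvision.transforms.functional as TF

    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="stabilityai/sd-turbo")
    parser.add_argument("--osediff_path", type=str, required=True)
    parser.add_argument("--mixed_precision", type=str, default="fp16")
    parser.add_argument("--merge_and_unload_lora", action="store_true")
    args = parser.parse_args()

    model = OSEDiff_test(args)

    # scale the low-quality input to [-1, 1], the range the VAE encode/decode expects
    lq = Image.open("input.png").convert("RGB")
    lq = (TF.to_tensor(lq) * 2 - 1).unsqueeze(0).to(model.device)

    out = model(lq, prompt="a high-quality photo")  # forward() already runs under no_grad
    out = (out[0].float().cpu() * 0.5 + 0.5).clamp(0, 1)
    TF.to_pil_image(out).save("output.png")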