Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os
- import sys
- sys.path.append(os.getcwd())
- import yaml
- import copy
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from transformers import AutoTokenizer, CLIPTextModel
- from diffusers import DDPMScheduler
- from models.autoencoder_kl import AutoencoderKL
- from models.unet_2d_condition import UNet2DConditionModel
- from peft import LoraConfig
- from my_utils.vaehook import VAEHook, perfcount
class OSEDiff_test(torch.nn.Module):
    """One-step OSEDiff inference wrapper around Stable-Diffusion components.

    Loads the tokenizer, CLIP text encoder, DDPM scheduler, VAE and UNet from
    ``args.pretrained_model_name_or_path``, attaches the LoRA adapters stored in
    the checkpoint at ``args.osediff_path`` and runs a single denoising step at
    timestep 999 in ``forward``.

    Attributes read from ``args``: ``pretrained_model_name_or_path``,
    ``osediff_path``, ``mixed_precision`` (``"fp16"`` selects float16, anything
    else float32) and ``merge_and_unload_lora``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # Single source of truth for placement. Every device-dependent call below
        # uses it, so the model also works on CPU-only hosts (several calls were
        # previously hard-coded to "cuda" despite this fallback).
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        base_path = args.pretrained_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder")
        self.noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
        self.noise_scheduler.set_timesteps(1, device=self.device)
        self.vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet")
        # NOTE: tiled VAE encoding/decoding is intentionally NOT installed here.
        # To re-enable it, call
        # self._init_tiled_vae(encoder_tile_size=..., decoder_tile_size=...).

        self.weight_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.float32

        # map_location="cpu" lets a checkpoint saved from a GPU load on any host;
        # the modules are moved to self.device further below.
        # NOTE(review): torch.load unpickles arbitrary Python objects -- only load
        # trusted checkpoints (consider weights_only=True on recent torch).
        osediff = torch.load(args.osediff_path, map_location="cpu")
        self.load_ckpt(osediff)

        # Optionally fold the LoRA weights into the base weights.
        if args.merge_and_unload_lora:
            print('===> MERGE LORA <===')
            self.vae = self.vae.merge_and_unload()
            self.unet = self.unet.merge_and_unload()

        for module in (self.unet, self.vae, self.text_encoder):
            module.to(self.device, dtype=self.weight_dtype)

        # Fixed timestep used for the single denoising step in forward().
        self.timesteps = torch.tensor([999], dtype=torch.long, device=self.device)
        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device)

    def load_ckpt(self, model):
        """Attach LoRA adapters to the UNet and VAE and fill them from a checkpoint.

        Args:
            model: checkpoint dict (as returned by ``torch.load``). Keys read:
                ``rank_unet``, ``unet_lora_encoder_modules``,
                ``unet_lora_decoder_modules``, ``unet_lora_others_modules``,
                ``state_dict_unet``, ``rank_vae``, ``vae_lora_encoder_modules``
                and ``state_dict_vae``.

        Raises:
            KeyError: if a required key, or the name of a LoRA / ``conv_in``
                parameter, is missing from the checkpoint.
        """
        # UNet: three adapters (encoder / decoder / others) sharing one rank.
        # The gaussian init is irrelevant because the weights are overwritten below.
        unet_rank = model["rank_unet"]
        unet_adapter_names = []
        for group in ("encoder", "decoder", "others"):
            adapter_name = f"default_{group}"
            lora_conf = LoraConfig(
                r=unet_rank,
                init_lora_weights="gaussian",
                target_modules=model[f"unet_lora_{group}_modules"],
            )
            self.unet.add_adapter(lora_conf, adapter_name=adapter_name)
            unet_adapter_names.append(adapter_name)

        unet_state = model["state_dict_unet"]
        with torch.no_grad():
            for name, param in self.unet.named_parameters():
                # conv_in is also restored from the checkpoint, not only LoRA weights.
                if "lora" in name or "conv_in" in name:
                    param.copy_(unet_state[name])
        self.unet.set_adapter(unet_adapter_names)

        # VAE: a single adapter on the encoder modules.
        vae_lora_conf = LoraConfig(
            r=model["rank_vae"],
            init_lora_weights="gaussian",
            target_modules=model["vae_lora_encoder_modules"],
        )
        self.vae.add_adapter(vae_lora_conf, adapter_name="default_encoder")
        vae_state = model["state_dict_vae"]
        with torch.no_grad():
            for name, param in self.vae.named_parameters():
                if "lora" in name:
                    param.copy_(vae_state[name])
        self.vae.set_adapter(['default_encoder'])

    def encode_prompt(self, prompt_batch):
        """Encode prompts with the CLIP text encoder.

        Args:
            prompt_batch: iterable of prompt strings. Each one is tokenized
                separately, padded/truncated to ``tokenizer.model_max_length``.

        Returns:
            The text encoder's first output for every prompt, concatenated
            along dim 0.
        """
        with torch.no_grad():
            prompt_embeds_list = [
                self.text_encoder(
                    self.tokenizer(
                        caption,
                        max_length=self.tokenizer.model_max_length,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt",
                    ).input_ids.to(self.text_encoder.device),
                )[0]
                for caption in prompt_batch
            ]
        return torch.concat(prompt_embeds_list, dim=0)

    @perfcount
    @torch.no_grad()
    def forward(self, lq, prompt):
        """Restore ``lq`` in a single UNet step conditioned on ``prompt``.

        Args:
            lq: image tensor accepted by the VAE encoder (presumably
                (B, 3, H, W) in [-1, 1] -- TODO confirm against the caller).
                It is moved to the model's device and weight dtype.
            prompt: a single prompt string.

        Returns:
            The VAE-decoded image, clamped to [-1, 1].
        """
        prompt_embeds = self.encode_prompt([prompt])
        lq = lq.to(device=self.device, dtype=self.weight_dtype)
        lq_latent = self.vae.encode(lq).latent_dist.sample() * self.vae.config.scaling_factor
        # The whole latent is passed to the UNet in one call (latent tiling removed).
        model_pred = self.unet(lq_latent, self.timesteps, encoder_hidden_states=prompt_embeds).sample
        x_denoised = self.noise_scheduler.step(
            model_pred, self.timesteps, lq_latent, return_dict=True
        ).prev_sample
        decoded = self.vae.decode(
            x_denoised.to(self.weight_dtype) / self.vae.config.scaling_factor
        ).sample
        return decoded.clamp(-1, 1)

    def _init_tiled_vae(self,
                        encoder_tile_size=256,
                        decoder_tile_size=256,
                        fast_decoder=False,
                        fast_encoder=False,
                        color_fix=False,
                        vae_to_gpu=True):
        """Replace the VAE encoder/decoder ``forward`` with tiled ``VAEHook`` wrappers.

        The untouched forwards are stashed once as ``original_forward``, so
        calling this repeatedly does not lose the original implementation.
        """
        encoder = self.vae.encoder
        decoder = self.vae.decoder
        # Save the original forward only on the first call.
        if not hasattr(encoder, 'original_forward'):
            encoder.original_forward = encoder.forward
        if not hasattr(decoder, 'original_forward'):
            decoder.original_forward = decoder.forward
        encoder.forward = VAEHook(
            encoder, encoder_tile_size, is_decoder=False, fast_decoder=fast_decoder,
            fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)
        decoder.forward = VAEHook(
            decoder, decoder_tile_size, is_decoder=True, fast_decoder=fast_decoder,
            fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)

    def _gaussian_weights(self, tile_width, tile_height, nbatches):
        """Generate a Gaussian mask of weights for tile contributions.

        Returns:
            float64 tensor of shape
            ``(nbatches, unet.config.in_channels, tile_height, tile_width)`` on
            ``self.device``. It is the outer product of two 1-D Gaussian profiles,
            each peaked at the centre of its axis.
        """
        import math

        var = 0.01
        norm = math.sqrt(2 * math.pi * var)

        def profile(length):
            # Centre at (length - 1) / 2 so indices 0..length-1 are symmetric
            # around the peak. The y-axis previously used length / 2, which
            # shifted its peak by half a sample relative to the x-axis.
            coords = torch.arange(length, dtype=torch.float64)
            midpoint = (length - 1) / 2
            return torch.exp(-(coords - midpoint) ** 2 / (length * length) / (2 * var)) / norm

        weights = torch.outer(profile(tile_height), profile(tile_width)).to(self.device)
        return torch.tile(weights, (nbatches, self.unet.config.in_channels, 1, 1))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement