Advertisement
MunchkinT

BG Replacement Script

Dec 10th, 2024 (edited)
37
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.16 KB | Software | 0 0
  1. import os
  2. import random
  3. import sys
  4. from typing import Sequence, Mapping, Any, Union
  5. import torch
  6.  
  7.  
  8. def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
  9.     """Returns the value at the given index of a sequence or mapping.
  10.  
  11.    If the object is a sequence (like list or string), returns the value at the given index.
  12.    If the object is a mapping (like a dictionary), returns the value at the index-th key.
  13.  
  14.    Some return a dictionary, in these cases, we look for the "results" key
  15.  
  16.    Args:
  17.        obj (Union[Sequence, Mapping]): The object to retrieve the value from.
  18.        index (int): The index of the value to retrieve.
  19.  
  20.    Returns:
  21.        Any: The value at the given index.
  22.  
  23.    Raises:
  24.        IndexError: If the index is out of bounds for the object and the object is not a mapping.
  25.    """
  26.     try:
  27.         return obj[index]
  28.     except KeyError:
  29.         return obj["result"][index]
  30.  
  31.  
  32. def find_path(name: str, path: str = None) -> str:
  33.     """
  34.    Recursively looks at parent folders starting from the given path until it finds the given name.
  35.    Returns the path as a Path object if found, or None otherwise.
  36.    """
  37.     # If no path is given, use the current working directory
  38.     if path is None:
  39.         path = os.getcwd()
  40.  
  41.     # Check if the current directory contains the name
  42.     if name in os.listdir(path):
  43.         path_name = os.path.join(path, name)
  44.         print(f"{name} found: {path_name}")
  45.         return path_name
  46.  
  47.     # Get the parent directory
  48.     parent_directory = os.path.dirname(path)
  49.  
  50.     # If the parent directory is the same as the current directory, we've reached the root and stop the search
  51.     if parent_directory == path:
  52.         return None
  53.  
  54.     # Recursively call the function with the parent directory
  55.     return find_path(name, parent_directory)
  56.  
  57.  
  58. def add_comfyui_directory_to_sys_path() -> None:
  59.     """
  60.    Add 'ComfyUI' to the sys.path
  61.    """
  62.     comfyui_path = find_path("ComfyUI")
  63.     if comfyui_path is not None and os.path.isdir(comfyui_path):
  64.         sys.path.append(comfyui_path)
  65.         print(f"'{comfyui_path}' added to sys.path")
  66.  
  67.  
  68. def add_extra_model_paths() -> None:
  69.     """
  70.    Parse the optional extra_model_paths.yaml file and add the parsed paths to the sys.path.
  71.    """
  72.     try:
  73.         from main import load_extra_path_config
  74.     except ImportError:
  75.         print(
  76.             "Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead."
  77.         )
  78.         from utils.extra_config import load_extra_path_config
  79.  
  80.     extra_model_paths = find_path("extra_model_paths.yaml")
  81.  
  82.     if extra_model_paths is not None:
  83.         load_extra_path_config(extra_model_paths)
  84.     else:
  85.         print("Could not find the extra_model_paths config file.")
  86.  
  87.  
# Module-level setup that runs on import. Order matters:
#   1. Put the ComfyUI directory on sys.path.
#   2. Load extra_model_paths.yaml. This step imports from `main` /
#      `utils.extra_config`, which presumably live in that ComfyUI directory.
# Both steps must precede the `from nodes import NODE_CLASS_MAPPINGS` later in
# this module, which likewise relies on the sys.path entry added here.
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
  90.  
  91.  
  92. def import_custom_nodes() -> None:
  93.     """Find all custom nodes in the custom_nodes folder and add those node objects to NODE_CLASS_MAPPINGS
  94.  
  95.    This function sets up a new asyncio event loop, initializes the PromptServer,
  96.    creates a PromptQueue, and initializes the custom nodes.
  97.    """
  98.     import asyncio
  99.     import execution
  100.     from nodes import init_extra_nodes
  101.     import server
  102.  
  103.     # Creating a new event loop and setting it as the default loop
  104.     loop = asyncio.new_event_loop()
  105.     asyncio.set_event_loop(loop)
  106.  
  107.     # Creating an instance of PromptServer with the loop
  108.     server_instance = server.PromptServer(loop)
  109.     execution.PromptQueue(server_instance)
  110.  
  111.     # Initializing custom nodes
  112.     init_extra_nodes()
  113.  
  114.  
  115. from nodes import NODE_CLASS_MAPPINGS
  116.  
  117.  
  118. def main():
  119.     import_custom_nodes()
  120.     with torch.inference_mode():
  121.         checkpointloadersimple = NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"]()
  122.         checkpointloadersimple_4 = checkpointloadersimple.load_checkpoint(
  123.             ckpt_name="XL\nightvisionxl_V900.safetensors"
  124.         )
  125.  
  126.         cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
  127.         cliptextencode_6 = cliptextencode.encode(
  128.             text="a woman, up close, neck",
  129.             clip=get_value_at_index(checkpointloadersimple_4, 1),
  130.         )
  131.  
  132.         cliptextencode_7 = cliptextencode.encode(
  133.             text="lowres, low quality, cropped, worst quality, watermark",
  134.             clip=get_value_at_index(checkpointloadersimple_4, 1),
  135.         )
  136.  
  137.         loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
  138.         loadimage_78 = loadimage.load_image(
  139.             image="TINY-SOLITAIRE-PEARL-NECKLACE-FRONT-1_750x-removebg-preview.png"
  140.         )
  141.  
  142.         imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
  143.         imageresize_14 = imageresize.execute(
  144.             width=1024,
  145.             height=1024,
  146.             interpolation="nearest",
  147.             method="stretch",
  148.             condition="always",
  149.             multiple_of=0,
  150.             image=get_value_at_index(loadimage_78, 0),
  151.         )
  152.  
  153.         emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
  154.         emptylatentimage_42 = emptylatentimage.generate(
  155.             width=get_value_at_index(imageresize_14, 1),
  156.             height=get_value_at_index(imageresize_14, 2),
  157.             batch_size=1,
  158.         )
  159.  
  160.         vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
  161.         vaedecode_43 = vaedecode.decode(
  162.             samples=get_value_at_index(emptylatentimage_42, 0),
  163.             vae=get_value_at_index(checkpointloadersimple_4, 2),
  164.         )
  165.  
  166.         easy_imagerembg = NODE_CLASS_MAPPINGS["easy imageRemBg"]()
  167.         easy_imagerembg_12 = easy_imagerembg.remove(
  168.             rem_mode="RMBG-1.4",
  169.             image_output="Preview",
  170.             save_prefix="ComfyUI",
  171.             torchscript_jit=False,
  172.             images=get_value_at_index(imageresize_14, 0),
  173.         )
  174.  
  175.         splitimagewithalpha = NODE_CLASS_MAPPINGS["SplitImageWithAlpha"]()
  176.         splitimagewithalpha_47 = splitimagewithalpha.split_image_with_alpha(
  177.             image=get_value_at_index(easy_imagerembg_12, 0)
  178.         )
  179.  
  180.         imagecompositemasked = NODE_CLASS_MAPPINGS["ImageCompositeMasked"]()
  181.         imagecompositemasked_46 = imagecompositemasked.composite(
  182.             x=0,
  183.             y=0,
  184.             resize_source=False,
  185.             destination=get_value_at_index(vaedecode_43, 0),
  186.             source=get_value_at_index(splitimagewithalpha_47, 0),
  187.             mask=get_value_at_index(easy_imagerembg_12, 1),
  188.         )
  189.  
  190.         vaeencodeargmax = NODE_CLASS_MAPPINGS["VAEEncodeArgMax"]()
  191.         vaeencodeargmax_37 = vaeencodeargmax.encode(
  192.             pixels=get_value_at_index(imagecompositemasked_46, 0),
  193.             vae=get_value_at_index(checkpointloadersimple_4, 2),
  194.         )
  195.  
  196.         easy_ipadapterapply = NODE_CLASS_MAPPINGS["easy ipadapterApply"]()
  197.         ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
  198.  
  199.         for q in range(1):
  200.             easy_ipadapterapply_58 = easy_ipadapterapply.apply(
  201.                 preset="PLUS (high strength)",
  202.                 lora_strength=0.6,
  203.                 provider="CPU",
  204.                 weight=1,
  205.                 weight_faceidv2=1,
  206.                 start_at=0,
  207.                 end_at=1,
  208.                 cache_mode="all",
  209.                 use_tiled=False,
  210.                 model=get_value_at_index(checkpointloadersimple_4, 0),
  211.                 image=get_value_at_index(imagecompositemasked_46, 0),
  212.                 attn_mask=get_value_at_index(easy_imagerembg_12, 1),
  213.             )
  214.  
  215.             ksampler_16 = ksampler.sample(
  216.                 seed=random.randint(1, 2**64),
  217.                 steps=25,
  218.                 cfg=4,
  219.                 sampler_name="dpmpp_2m_sde",
  220.                 scheduler="karras",
  221.                 denoise=0.9,
  222.                 model=get_value_at_index(easy_ipadapterapply_58, 0),
  223.                 positive=get_value_at_index(cliptextencode_6, 0),
  224.                 negative=get_value_at_index(cliptextencode_7, 0),
  225.                 latent_image=get_value_at_index(vaeencodeargmax_37, 0),
  226.             )
  227.  
  228.             vaedecode_17 = vaedecode.decode(
  229.                 samples=get_value_at_index(ksampler_16, 0),
  230.                 vae=get_value_at_index(checkpointloadersimple_4, 2),
  231.             )
  232.  
  233.  
  234. if __name__ == "__main__":
  235.     main()
  236.  
Tags: ComfyUI
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement