summaryrefslogblamecommitdiffstats
path: root/infer.py
blob: 0a219a5b03b82755fc95cad3395926d46aaca630 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
               
              

            
                        
                           
            
                
                     
                           








                                    
                                    
                           
                            
 
                                                                 
                                      
 
                                                          
                                                           
                                                                                
                                                            
                                   
 
                                            
                                     
 
                
                                                
                        
                                         
                                




                                     
                  
                         
                     
                            
                     
                  
                       
                  
                    
                
                          
                   












                                                                  
                         
                                                                                        

                        
                        



                                         


                              
                                

                        
                       


                        




                        
                                                                                        
                        




                                                
                      










                      
                        

                         









                      
                        


                        
                   
                  
                        
                            
                  
                        


                        


                        
                        
                   


                        


                        
                        
                       


                        


                        

                           
                   
                        


                        
                 
                        
                   
     

                 
                                             
                                            
                                    
 
                               
                                            

                                                                                       
 



                                                                      
 





                                                 
                                                  
                                                       
                                                    
                             
     
                                                         

                                                                                                     
 













                                                          

                                                                               
                            

                                              
     

                                                                                    








                                                                                     








                                                               

                                                                                           
 

          















                                                                                 
                                                         






                                                                
         
                                                            
 
                                  
                                                           
 











                                                                                  
 
                                          
 





                                   
                                                         
                                 
                       



                             
                                                     

                                                                                    
 
                       
                                               

                                    
                                                                          
                                                               

                             
                                  
                                                                      
             
                                         
                                  
                                                    
                                                  
         
                                                                          
                                                     
 
                                                
 
                               





                                                  

                                                                    
 
                                         
                                                                         
         
                            
                                                                              
                                                                    
                          
                          

                                                 
                                               
                                                  
                                     



                                          
                                                        
                                                 
 
                                                           
                                                         
                                                     
 

                                 
 
                        
                      
                 
 

                                                                                  
                          
 
                                    
                                                  
                                                      








                                        
                  
 
                              
                       
                                              
                                                                      
                  
            
                                                                  
                                     
                                                 
                      
                          
                                 
                                    
                  
                              
                                 
                  
 

                                                          
                                          
                              
                                 
                  





                                                              
 
                                      
                                                
 
                                      

                                                                                   
 
                                                 
 
                                                           
                                                             
 
                                    





                                 
                        


                          
import argparse
import datetime
import logging
import sys
import shlex
import cmd
from pathlib import Path
from typing import Optional
import torch
import json
import traceback

from PIL import Image
from slugify import slugify
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    PNDMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    DDIMScheduler,
    LMSDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    DEISMultistepScheduler,
    UniPCMultistepScheduler,
)
from peft import LoraConfig, LoraModel, set_peft_model_state_dict
from safetensors.torch import load_file
from transformers import CLIPTextModel

from data.keywords import str_to_keywords, keywords_to_str
from models.clip.embeddings import patch_managed_embeddings
from models.clip.tokenizer import MultiCLIPTokenizer
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from util.files import load_config, load_embeddings_from_dir
from util.ti import load_embeddings


# Speed-oriented backend settings: allow TF32 matmuls on Ampere+ GPUs and let
# cuDNN autotune convolution algorithms (fine here since input sizes repeat).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True


# Startup defaults (see create_args_parser); CLI / --config values override
# these via run_parser(). A value of None means "not set".
default_args = {
    "model": "stabilityai/stable-diffusion-2-1",
    "precision": "fp32",  # one of fp32 / fp16 / bf16 (see main())
    "ti_embeddings_dir": "embeddings_ti",
    "lora_embeddings_dir": None,
    "output_dir": "output/inference",
    "config": None,  # optional JSON config file merged in by run_parser()
}


# Per-command defaults (see create_cmd_parser); values given on the dream>
# prompt or via --config override these in run_parser(). None means "not set".
default_cmds = {
    "project": "",
    "scheduler": "unipc",
    "subscheduler": None,  # only used with "unipc" (passed as solver_p)
    "template": "{}",  # str.format template applied to every prompt
    "prompt": None,
    "negative_prompt": None,
    "shuffle": False,  # shuffle comma-separated prompt keywords per batch
    "image": None,  # optional init image path (img2img)
    "image_noise": 0.7,  # img2img strength
    "width": 768,
    "height": 768,
    "batch_size": 1,  # images per prompt per batch
    "batch_num": 1,  # number of batches
    "steps": 30,
    "guidance_scale": 7.0,
    "sag_scale": 0,
    "seed": None,  # None -> fresh random seed each invocation
    "config": None,
}


def merge_dicts(d1, *args):
    """Merge dicts left to right, skipping None values.

    Returns a new dict: later dicts override earlier keys, but a value of
    None never overwrites an existing entry. `d1` is not modified.
    """
    merged = dict(d1)
    for overrides in args:
        for key, value in overrides.items():
            if value is not None:
                merged[key] = value
    return merged


def create_args_parser():
    """Build the argparse parser for one-time startup options.

    All options default to None so run_parser() can layer them over
    `default_args` (None means "use the default / config value").
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")

    # (name, extra kwargs) — every option is a plain string unless noted.
    string_options = [
        ("--model", {}),
        ("--precision", {"choices": ["fp32", "fp16", "bf16"]}),
        ("--ti_embeddings_dir", {}),
        ("--lora_embeddings_dir", {}),
        ("--output_dir", {}),
        ("--config", {}),
    ]
    for flag, extra in string_options:
        parser.add_argument(flag, type=str, **extra)

    return parser


def _parse_bool(value: str) -> bool:
    """Parse a human-entered boolean string for argparse.

    argparse's `type=bool` treats every non-empty string — including
    "False" — as True; this converter accepts the usual spellings and
    rejects anything else with a proper argparse error.
    """
    lowered = value.strip().lower()
    if lowered in ("1", "true", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


def create_cmd_parser():
    """Build the argparse parser for interactive dream> commands.

    Options default to None so run_parser() can layer them over
    `default_cmds` (None means "keep the default / config value").
    """
    parser = argparse.ArgumentParser(description="Simple example of an inference script.")
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        choices=[
            "plms",
            "ddim",
            "klms",
            "dpmsm",
            "dpmss",
            "euler_a",
            "kdpm2",
            "kdpm2_a",
            "deis",
            "unipc",
        ],
    )
    # Sub-scheduler is only consulted by "unipc" (passed as solver_p).
    parser.add_argument(
        "--subscheduler",
        type=str,
        default=None,
        choices=[
            "plms",
            "ddim",
            "klms",
            "dpmsm",
            "dpmss",
            "euler_a",
            "kdpm2",
            "kdpm2_a",
            "deis",
        ],
    )
    parser.add_argument(
        "--template",
        type=str,
    )
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        nargs="*",
    )
    # BUG FIX: was `type=bool`, for which any non-empty string (even
    # "False") parsed as True. _parse_bool handles true/false spellings.
    parser.add_argument(
        "--shuffle",
        type=_parse_bool,
    )
    parser.add_argument(
        "--image",
        type=str,
    )
    parser.add_argument(
        "--image_noise",
        type=float,
    )
    parser.add_argument(
        "--width",
        type=int,
    )
    parser.add_argument(
        "--height",
        type=int,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
    )
    parser.add_argument(
        "--batch_num",
        type=int,
    )
    parser.add_argument(
        "--steps",
        type=int,
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
    )
    parser.add_argument(
        "--sag_scale",
        type=float,
    )
    parser.add_argument(
        "--seed",
        type=int,
    )
    parser.add_argument(
        "--config",
        type=str,
    )

    return parser


def run_parser(parser, defaults, input=None):
    """Parse `input` with `parser`, layering config-file values over defaults.

    Precedence (lowest to highest): `defaults`, values from the file named
    by --config (re-parsed through `parser` for validation), then the
    command-line values themselves. A parsed value of None means "not set"
    and never overrides a lower layer.

    NOTE: `input` shadows the builtin but is kept for caller compatibility.
    Returns an argparse.Namespace.
    """
    args = parser.parse_known_args(input)[0]
    conf_args = argparse.Namespace()

    if args.config is not None:
        conf_args = load_config(args.config)
        conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[
            0
        ]

    res = defaults.copy()
    # Renamed from `dict` (shadowed the builtin); same merge semantics.
    for source in (vars(conf_args), vars(args)):
        res.update({k: v for (k, v) in source.items() if v is not None})

    return argparse.Namespace(**res)


def save_args(basepath, args, extra=None):
    """Write `args` (plus optional `extra` entries) to `basepath`/args.json.

    Fixes two latent bugs: the mutable default `extra={}`, and the fact
    that `vars(args)` returns the namespace's live __dict__ — updating it
    in place silently added `extra` keys to the caller's `args`.
    """
    # Copy so that merging `extra` cannot mutate the caller's namespace.
    info = {"args": dict(vars(args))}
    if extra:
        info["args"].update(extra)
    with open(f"{basepath}/args.json", "w") as f:
        json.dump(info, f, indent=4)


def load_embeddings_dir(pipeline, embeddings_dir):
    """Load every textual-inversion embedding found in `embeddings_dir`
    into the pipeline's tokenizer/text-encoder, then persist them."""
    added_tokens, added_ids = load_embeddings_from_dir(
        pipeline.tokenizer,
        pipeline.text_encoder.text_model.embeddings,
        Path(embeddings_dir),
    )
    # Bake the freshly loaded vectors into the managed embedding table.
    pipeline.text_encoder.text_model.embeddings.persist()
    report = list(zip(added_tokens, added_ids))
    print(f"Added {len(added_tokens)} tokens from embeddings dir: {report}")


def load_lora(pipeline, path):
    """Load LoRA weights (and any bundled TI embeddings) from `path` into
    the pipeline's UNet and text encoder.

    Expects `path` to contain lora_config.json and a *_end.safetensors
    checkpoint. No-op when `path` is None or no checkpoint is found.
    """
    if path is None:
        return

    path = Path(path)

    with open(path / "lora_config.json", "r") as f:
        lora_config = json.load(f)

    # sorted() makes the choice deterministic when several checkpoints exist.
    tensor_files = sorted(path.glob("*_end.safetensors"))

    if len(tensor_files) == 0:
        return

    # BUG FIX: glob() already yields paths rooted at `path`; the previous
    # `path / tensor_files[0]` doubled the directory prefix.
    lora_checkpoint_sd = load_file(tensor_files[0])
    unet_lora_ds = {
        k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k
    }
    text_encoder_lora_ds = {
        k.replace("text_encoder_", ""): v
        for k, v in lora_checkpoint_sd.items()
        if "text_encoder_" in k
    }
    ti_lora_ds = {
        k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
    }

    unet_config = LoraConfig(**lora_config["peft_config"])
    pipeline.unet = LoraModel(unet_config, pipeline.unet)
    set_peft_model_state_dict(pipeline.unet, unet_lora_ds)

    if "text_encoder_peft_config" in lora_config:
        text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
        pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder)
        set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds)

    # BUG FIX: iterating a dict yields keys only, so `for k, _ in ti_lora_ds`
    # tried to unpack each key string as a pair and raised ValueError.
    tokens = list(ti_lora_ds.keys())
    token_embeddings = list(ti_lora_ds.values())

    added_tokens, added_ids = load_embeddings(
        tokenizer=pipeline.tokenizer,
        embeddings=pipeline.text_encoder.text_model.embeddings,
        tokens=tokens,
        token_embeddings=token_embeddings,
    )
    pipeline.text_encoder.text_model.embeddings.persist()
    print(
        f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}"
    )


def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None):
    """Instantiate a diffusers noise scheduler by short name.

    `subscheduler` is only consulted for "unipc", where it is built
    recursively and passed as `solver_p`. Raises ValueError for an
    unknown scheduler name.
    """
    # Lazy factories: nothing is constructed until the chosen entry is called.
    factories = {
        "plms": lambda: PNDMScheduler.from_config(config),
        "klms": lambda: LMSDiscreteScheduler.from_config(config),
        "ddim": lambda: DDIMScheduler.from_config(config),
        "dpmsm": lambda: DPMSolverMultistepScheduler.from_config(config),
        "dpmss": lambda: DPMSolverSinglestepScheduler.from_config(config),
        "euler_a": lambda: EulerAncestralDiscreteScheduler.from_config(config),
        "kdpm2": lambda: KDPM2DiscreteScheduler.from_config(config),
        "kdpm2_a": lambda: KDPM2AncestralDiscreteScheduler.from_config(config),
        "deis": lambda: DEISMultistepScheduler.from_config(config),
    }

    if scheduler == "unipc":
        if subscheduler is None:
            return UniPCMultistepScheduler.from_config(config)
        return UniPCMultistepScheduler.from_config(
            config,
            solver_p=create_scheduler(config, subscheduler),
        )

    if scheduler not in factories:
        raise ValueError(f'Unknown scheduler "{scheduler}"')

    return factories[scheduler]()


def create_pipeline(model, dtype):
    """Assemble a VlpnStableDiffusion pipeline from a pretrained model.

    Loads each component from its subfolder at the requested dtype, patches
    the text encoder's embeddings for managed TI tokens, enables xformers
    attention and VAE slicing, and moves the pipeline to CUDA.
    """
    print(f"Loading Stable Diffusion pipeline: {model}...")

    def from_subfolder(cls, name):
        # All components share the same model repo and dtype.
        return cls.from_pretrained(model, subfolder=name, torch_dtype=dtype)

    tokenizer = from_subfolder(MultiCLIPTokenizer, "tokenizer")
    text_encoder = from_subfolder(CLIPTextModel, "text_encoder")
    vae = from_subfolder(AutoencoderKL, "vae")
    unet = from_subfolder(UNet2DConditionModel, "unet")
    scheduler = from_subfolder(DDIMScheduler, "scheduler")

    # Swap in managed embeddings before wiring the pipeline together.
    patch_managed_embeddings(text_encoder)

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
    )
    pipeline.enable_xformers_memory_efficient_attention()
    # pipeline.unet = torch.compile(pipeline.unet)
    pipeline.enable_vae_slicing()
    pipeline.to("cuda")

    print("Pipeline loaded.")

    return pipeline


def shuffle_prompts(prompts: list[str]) -> list[str]:
    """Return each prompt with its comma-separated keywords shuffled."""
    shuffled = []
    for prompt in prompts:
        keywords = str_to_keywords(prompt)
        shuffled.append(keywords_to_str(keywords, shuffle=True))
    return shuffled


@torch.inference_mode()
def generate(output_dir: Path, pipeline, args):
    """Generate images for `args.prompt` and save PNG/JPG plus prompt text.

    Builds a timestamped output directory (one sub-directory per prompt
    when several are given), resolves the scheduler and optional init
    image, then runs `args.batch_num` batches through the pipeline.

    NOTE: mutates `args` in place (templated prompts, resolved seed) so
    that save_args() records the effective values.
    """
    if isinstance(args.prompt, str):
        args.prompt = [args.prompt]

    args.prompt = [args.template.format(prompt) for prompt in args.prompt]

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    image_dirs = []

    if len(args.prompt) != 1:
        if len(args.project) != 0:
            output_dir = output_dir / f"{now}_{slugify(args.project)}"
        else:
            output_dir = output_dir / now

        # One sub-directory per prompt, keyed by a slug of the prompt text.
        for prompt in args.prompt:
            prompt_dir = output_dir / slugify(prompt)[:100]
            prompt_dir.mkdir(parents=True, exist_ok=True)
            image_dirs.append(prompt_dir)
    else:
        output_dir = output_dir / f"{now}_{slugify(args.prompt[0])[:100]}"
        output_dir.mkdir(parents=True, exist_ok=True)
        image_dirs.append(output_dir)

    # BUG FIX: `args.seed or torch.random.seed()` discarded an explicit
    # seed of 0; only replace the seed when it was not provided at all.
    if args.seed is None:
        args.seed = torch.random.seed()

    save_args(output_dir, args)

    if args.image:
        init_image = Image.open(args.image)
        if init_image.mode != "RGB":
            init_image = init_image.convert("RGB")
    else:
        init_image = None

    pipeline.scheduler = create_scheduler(
        pipeline.scheduler.config, args.scheduler, args.subscheduler
    )

    for i in range(args.batch_num):
        pipeline.set_progress_bar_config(
            desc=f"Batch {i + 1} of {args.batch_num}", dynamic_ncols=True
        )

        # Each batch derives its own deterministic seed from the base seed.
        seed = args.seed + i
        prompt = shuffle_prompts(args.prompt) if args.shuffle else args.prompt
        generator = torch.Generator(device="cuda").manual_seed(seed)
        images = pipeline(
            prompt=prompt,
            negative_prompt=args.negative_prompt,
            height=args.height,
            width=args.width,
            generator=generator,
            guidance_scale=args.guidance_scale,
            num_images_per_prompt=args.batch_size,
            num_inference_steps=args.steps,
            sag_scale=args.sag_scale,
            image=init_image,
            strength=args.image_noise,
        ).images

        for j, image in enumerate(images):
            # NOTE(review): this indexing assumes returned images alternate
            # across prompts (j % n_prompts selects the prompt, j // n_prompts
            # numbers the sample) — matches the original code; verify against
            # the pipeline's output ordering.
            basename = f"{seed}_{j // len(args.prompt)}"
            target_dir = image_dirs[j % len(args.prompt)]

            image.save(target_dir / f"{basename}.png")
            image.save(target_dir / f"{basename}.jpg", quality=85)
            with open(target_dir / f"{basename}.txt", "w") as f:
                f.write(prompt[j % len(args.prompt)])

    if torch.cuda.is_available():
        torch.cuda.empty_cache()


class CmdParse(cmd.Cmd):
    """Interactive shell ("dream> ") that parses commands and runs generation.

    Recognized inputs: "q" / "exit" to quit, "reload_embeddings" to re-load
    TI embeddings, anything else is parsed as generation arguments.
    """

    prompt = "dream> "
    commands = []

    def __init__(
        self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser
    ):
        super().__init__()

        self.output_dir = output_dir
        self.ti_embeddings_dir = ti_embeddings_dir
        self.lora_embeddings_dir = lora_embeddings_dir
        self.pipeline = pipeline
        self.parser = parser

    def default(self, line):
        """Handle any non-builtin input line; return True to exit the loop."""
        # Escape single quotes so shlex keeps them inside prompt text.
        line = line.replace("'", "\\'")

        try:
            elements = shlex.split(line)
        except ValueError as e:
            print(str(e))
            return

        # A line of only quote characters can split to nothing.
        if not elements:
            return

        if elements[0] == "q":
            return True

        if elements[0] == "reload_embeddings":
            load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
            return

        try:
            args = run_parser(self.parser, default_cmds, elements)

            # BUG FIX: `len(args.prompt) == 0` raised TypeError when prompt
            # was None (default), dumping a traceback instead of this hint.
            if not args.prompt:
                print("Try again with a prompt!")
                return
        except SystemExit:
            # argparse exits on bad arguments; show help instead of dying.
            traceback.print_exc()
            self.parser.print_help()
            return
        except Exception:
            traceback.print_exc()
            return

        try:
            generate(self.output_dir, self.pipeline, args)
        except KeyboardInterrupt:
            # Best-effort cancel: abandon the current generation only.
            print("Generation cancelled.")
        except Exception:
            traceback.print_exc()
            return

    def do_exit(self, line):
        """Exit the shell."""
        return True


def main():
    """Entry point: build the pipeline, then hand control to the shell."""
    logging.basicConfig(stream=sys.stdout, level=logging.WARN)

    startup_args = run_parser(create_args_parser(), default_args)

    output_dir = Path(startup_args.output_dir)
    precision_to_dtype = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    dtype = precision_to_dtype[startup_args.precision]

    pipeline = create_pipeline(startup_args.model, dtype)

    # Optional extras, currently disabled:
    # load_embeddings_dir(pipeline, startup_args.ti_embeddings_dir)
    # load_lora(pipeline, startup_args.lora_embeddings_dir)
    # pipeline.unet.load_attn_procs(startup_args.lora_embeddings_dir)

    shell = CmdParse(
        output_dir,
        startup_args.ti_embeddings_dir,
        startup_args.lora_embeddings_dir,
        pipeline,
        create_cmd_parser(),
    )
    shell.cmdloop()


# Launch the interactive shell when executed as a script.
if __name__ == "__main__":
    main()