| | | |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2022-09-30 14:13:51 +0200 |
| committer | Volpeon <git@volpeon.ink> | 2022-09-30 14:13:51 +0200 |
| commit | 9a42def9fcfb9a5c5471d640253ed6c8f45c4973 (patch) | |
| tree | ad186862f5095663966dd1d42455023080aa0c4e /infer.py | |
| parent | Better sample file structure (diff) | |
| download | textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.gz textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.bz2 textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.zip | |
Added custom SD pipeline + euler_a scheduler
Diffstat (limited to 'infer.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | infer.py | 111 |

1 file changed, 80 insertions, 31 deletions
@@ -1,18 +1,15 @@
 import argparse
 import datetime
+import logging
 from pathlib import Path
 from torch import autocast
-from diffusers import StableDiffusionPipeline
 import torch
 import json
-from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
+from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
+from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
 from slugify import slugify
-from pipelines.stable_diffusion.no_check import NoCheck
-
-model_id = "path-to-your-trained-model"
-
-prompt = "A photo of sks dog in a bucket"
+from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
+from schedulers.scheduling_euler_a import EulerAScheduler
 
 
 def parse_args():
@@ -30,6 +27,21 @@ def parse_args():
         default=None,
     )
     parser.add_argument(
+        "--negative_prompt",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=512,
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=512,
+    )
+    parser.add_argument(
         "--batch_size",
         type=int,
         default=1,
@@ -42,17 +54,28 @@ def parse_args():
     parser.add_argument(
         "--steps",
         type=int,
-        default=80,
+        default=120,
+    )
+    parser.add_argument(
+        "--scheduler",
+        type=str,
+        choices=["plms", "ddim", "klms", "euler_a"],
+        default="euler_a",
     )
     parser.add_argument(
-        "--scale",
+        "--guidance_scale",
         type=int,
         default=7.5,
     )
     parser.add_argument(
+        "--clip_guidance_scale",
+        type=int,
+        default=100,
+    )
+    parser.add_argument(
         "--seed",
         type=int,
-        default=None,
+        default=torch.random.seed(),
     )
     parser.add_argument(
         "--output_dir",
@@ -81,31 +104,39 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def main():
-    args = parse_args()
-
-    seed = args.seed or torch.random.seed()
-
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}")
-    output_dir.mkdir(parents=True, exist_ok=True)
-    save_args(output_dir, args)
-
+def gen(args, output_dir):
     tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
     text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
+    clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
     vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
     unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
-    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16)
+    feature_extractor = CLIPFeatureExtractor.from_pretrained(
+        "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
 
-    pipeline = StableDiffusionPipeline(
+    if args.scheduler == "plms":
+        scheduler = PNDMScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+        )
+    elif args.scheduler == "klms":
+        scheduler = LMSDiscreteScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+        )
+    elif args.scheduler == "ddim":
+        scheduler = DDIMScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
+        )
+    else:
+        scheduler = EulerAScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
+        )
+
+    pipeline = CLIPGuidedStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=PNDMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-        ),
-        safety_checker=NoCheck(),
+        clip_model=clip_model,
+        scheduler=scheduler,
         feature_extractor=feature_extractor
     )
     pipeline.enable_attention_slicing()
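The `if`/`elif` chain above dispatches on the new `--scheduler` flag. For reference, an equivalent factory-function sketch (editor's refactor, not part of the commit; constructor arguments copied from the diff, and `EulerAScheduler` is the repo-local class imported at the top of infer.py):

```python
from diffusers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler

from schedulers.scheduling_euler_a import EulerAScheduler  # repo-local module

# Shared beta settings used by every branch in the diff.
COMMON = dict(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")


def make_scheduler(name: str):
    """Map a --scheduler choice to a configured scheduler instance."""
    if name == "plms":
        return PNDMScheduler(skip_prk_steps=True, **COMMON)
    if name == "klms":
        return LMSDiscreteScheduler(**COMMON)
    if name == "ddim":
        return DDIMScheduler(clip_sample=False, set_alpha_to_one=False, **COMMON)
    return EulerAScheduler(clip_sample=False, set_alpha_to_one=False, **COMMON)  # "euler_a"
```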
@@ -113,16 +144,34 @@ def main():
 
     with autocast("cuda"):
         for i in range(args.batch_num):
-            generator = torch.Generator(device="cuda").manual_seed(seed + i)
+            generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
             images = pipeline(
-                [args.prompt] * args.batch_size,
+                prompt=[args.prompt] * args.batch_size,
+                height=args.height,
+                width=args.width,
+                negative_prompt=args.negative_prompt,
                 num_inference_steps=args.steps,
-                guidance_scale=args.scale,
+                guidance_scale=args.guidance_scale,
+                clip_guidance_scale=args.clip_guidance_scale,
                 generator=generator,
             ).images
 
             for j, image in enumerate(images):
-                image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))
+                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))
+
+
+def main():
+    args = parse_args()
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    save_args(output_dir, args)
+
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+
+    gen(args, output_dir)
 
 
 if __name__ == "__main__":
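For orientation, a hedged end-to-end invocation of the updated script (editor's sketch). The `--model`, `--prompt`, `--output_dir`, and `--batch_num` flag spellings are inferred from the `args.*` attributes used in the unchanged part of infer.py; the model path and prompts are placeholders. Note that `--guidance_scale` and `--clip_guidance_scale` are declared with `type=int`, so fractional values such as the 7.5 default will be rejected if passed explicitly on the command line:

```python
# Hypothetical driver for the updated infer.py; flag names for the new options
# are taken from the diff, paths and prompts are placeholders.
import subprocess

subprocess.run(
    [
        "python", "infer.py",
        "--model", "models/my-finetuned-sd",           # placeholder path
        "--prompt", "a photo of a corgi in a bucket",  # placeholder prompt
        "--negative_prompt", "blurry, low quality",
        "--width", "512",
        "--height", "512",
        "--scheduler", "euler_a",
        "--steps", "120",
        "--guidance_scale", "7",         # int only: type=int would reject 7.5
        "--clip_guidance_scale", "100",
        "--seed", "42",
    ],
    check=True,
)
```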
