diff options
author | Volpeon <git@volpeon.ink> | 2022-09-30 14:13:51 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-09-30 14:13:51 +0200 |
commit | 9a42def9fcfb9a5c5471d640253ed6c8f45c4973 (patch) | |
tree | ad186862f5095663966dd1d42455023080aa0c4e /infer.py | |
parent | Better sample file structure (diff) | |
download | textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.gz textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.bz2 textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.zip |
Added custom SD pipeline + euler_a scheduler
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 111 |
1 file changed, 80 insertions, 31 deletions
@@ -1,18 +1,15 @@ | |||
1 | import argparse | 1 | import argparse |
2 | import datetime | 2 | import datetime |
3 | import logging | ||
3 | from pathlib import Path | 4 | from pathlib import Path |
4 | from torch import autocast | 5 | from torch import autocast |
5 | from diffusers import StableDiffusionPipeline | ||
6 | import torch | 6 | import torch |
7 | import json | 7 | import json |
8 | from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler | 8 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler |
9 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor | 9 | from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor |
10 | from slugify import slugify | 10 | from slugify import slugify |
11 | from pipelines.stable_diffusion.no_check import NoCheck | 11 | from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion |
12 | 12 | from schedulers.scheduling_euler_a import EulerAScheduler | |
13 | model_id = "path-to-your-trained-model" | ||
14 | |||
15 | prompt = "A photo of sks dog in a bucket" | ||
16 | 13 | ||
17 | 14 | ||
18 | def parse_args(): | 15 | def parse_args(): |
@@ -30,6 +27,21 @@ def parse_args(): | |||
30 | default=None, | 27 | default=None, |
31 | ) | 28 | ) |
32 | parser.add_argument( | 29 | parser.add_argument( |
30 | "--negative_prompt", | ||
31 | type=str, | ||
32 | default=None, | ||
33 | ) | ||
34 | parser.add_argument( | ||
35 | "--width", | ||
36 | type=int, | ||
37 | default=512, | ||
38 | ) | ||
39 | parser.add_argument( | ||
40 | "--height", | ||
41 | type=int, | ||
42 | default=512, | ||
43 | ) | ||
44 | parser.add_argument( | ||
33 | "--batch_size", | 45 | "--batch_size", |
34 | type=int, | 46 | type=int, |
35 | default=1, | 47 | default=1, |
@@ -42,17 +54,28 @@ def parse_args(): | |||
42 | parser.add_argument( | 54 | parser.add_argument( |
43 | "--steps", | 55 | "--steps", |
44 | type=int, | 56 | type=int, |
45 | default=80, | 57 | default=120, |
46 | ) | 58 | ) |
47 | parser.add_argument( | 59 | parser.add_argument( |
48 | "--scale", | 60 | "--scheduler", |
61 | type=str, | ||
62 | choices=["plms", "ddim", "klms", "euler_a"], | ||
63 | default="euler_a", | ||
64 | ) | ||
65 | parser.add_argument( | ||
66 | "--guidance_scale", | ||
49 | type=int, | 67 | type=int, |
50 | default=7.5, | 68 | default=7.5, |
51 | ) | 69 | ) |
52 | parser.add_argument( | 70 | parser.add_argument( |
71 | "--clip_guidance_scale", | ||
72 | type=int, | ||
73 | default=100, | ||
74 | ) | ||
75 | parser.add_argument( | ||
53 | "--seed", | 76 | "--seed", |
54 | type=int, | 77 | type=int, |
55 | default=None, | 78 | default=torch.random.seed(), |
56 | ) | 79 | ) |
57 | parser.add_argument( | 80 | parser.add_argument( |
58 | "--output_dir", | 81 | "--output_dir", |
@@ -81,31 +104,39 @@ def save_args(basepath, args, extra={}): | |||
81 | json.dump(info, f, indent=4) | 104 | json.dump(info, f, indent=4) |
82 | 105 | ||
83 | 106 | ||
84 | def main(): | 107 | def gen(args, output_dir): |
85 | args = parse_args() | ||
86 | |||
87 | seed = args.seed or torch.random.seed() | ||
88 | |||
89 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | ||
90 | output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") | ||
91 | output_dir.mkdir(parents=True, exist_ok=True) | ||
92 | save_args(output_dir, args) | ||
93 | |||
94 | tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) | 108 | tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) |
95 | text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) | 109 | text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) |
110 | clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) | ||
96 | vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16) | 111 | vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16) |
97 | unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16) | 112 | unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16) |
98 | feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16) | 113 | feature_extractor = CLIPFeatureExtractor.from_pretrained( |
114 | "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) | ||
99 | 115 | ||
100 | pipeline = StableDiffusionPipeline( | 116 | if args.scheduler == "plms": |
117 | scheduler = PNDMScheduler( | ||
118 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | ||
119 | ) | ||
120 | elif args.scheduler == "klms": | ||
121 | scheduler = LMSDiscreteScheduler( | ||
122 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
123 | ) | ||
124 | elif args.scheduler == "ddim": | ||
125 | scheduler = DDIMScheduler( | ||
126 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False | ||
127 | ) | ||
128 | else: | ||
129 | scheduler = EulerAScheduler( | ||
130 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False | ||
131 | ) | ||
132 | |||
133 | pipeline = CLIPGuidedStableDiffusion( | ||
101 | text_encoder=text_encoder, | 134 | text_encoder=text_encoder, |
102 | vae=vae, | 135 | vae=vae, |
103 | unet=unet, | 136 | unet=unet, |
104 | tokenizer=tokenizer, | 137 | tokenizer=tokenizer, |
105 | scheduler=PNDMScheduler( | 138 | clip_model=clip_model, |
106 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | 139 | scheduler=scheduler, |
107 | ), | ||
108 | safety_checker=NoCheck(), | ||
109 | feature_extractor=feature_extractor | 140 | feature_extractor=feature_extractor |
110 | ) | 141 | ) |
111 | pipeline.enable_attention_slicing() | 142 | pipeline.enable_attention_slicing() |
@@ -113,16 +144,34 @@ def main(): | |||
113 | 144 | ||
114 | with autocast("cuda"): | 145 | with autocast("cuda"): |
115 | for i in range(args.batch_num): | 146 | for i in range(args.batch_num): |
116 | generator = torch.Generator(device="cuda").manual_seed(seed + i) | 147 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) |
117 | images = pipeline( | 148 | images = pipeline( |
118 | [args.prompt] * args.batch_size, | 149 | prompt=[args.prompt] * args.batch_size, |
150 | height=args.height, | ||
151 | width=args.width, | ||
152 | negative_prompt=args.negative_prompt, | ||
119 | num_inference_steps=args.steps, | 153 | num_inference_steps=args.steps, |
120 | guidance_scale=args.scale, | 154 | guidance_scale=args.guidance_scale, |
155 | clip_guidance_scale=args.clip_guidance_scale, | ||
121 | generator=generator, | 156 | generator=generator, |
122 | ).images | 157 | ).images |
123 | 158 | ||
124 | for j, image in enumerate(images): | 159 | for j, image in enumerate(images): |
125 | image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) | 160 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) |
161 | |||
162 | |||
163 | def main(): | ||
164 | args = parse_args() | ||
165 | |||
166 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | ||
167 | output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") | ||
168 | output_dir.mkdir(parents=True, exist_ok=True) | ||
169 | |||
170 | save_args(output_dir, args) | ||
171 | |||
172 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | ||
173 | |||
174 | gen(args, output_dir) | ||
126 | 175 | ||
127 | 176 | ||
128 | if __name__ == "__main__": | 177 | if __name__ == "__main__": |