diff options
author | Volpeon <git@volpeon.ink> | 2022-10-01 11:40:14 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-01 11:40:14 +0200 |
commit | 5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688 (patch) | |
tree | a3461a4f1a04fba52ec8fde8b7b07095c7422d85 /infer.py | |
parent | Added custom SD pipeline + euler_a scheduler (diff) | |
download | textual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.tar.gz textual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.tar.bz2 textual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.zip |
Made inference script interactive
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 154 |
1 file changed, 109 insertions, 45 deletions
@@ -1,18 +1,21 @@ | |||
1 | import argparse | 1 | import argparse |
2 | import datetime | 2 | import datetime |
3 | import logging | 3 | import logging |
4 | import sys | ||
5 | import shlex | ||
6 | import cmd | ||
4 | from pathlib import Path | 7 | from pathlib import Path |
5 | from torch import autocast | 8 | from torch import autocast |
6 | import torch | 9 | import torch |
7 | import json | 10 | import json |
8 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler | 11 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler |
9 | from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor | 12 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor |
10 | from slugify import slugify | 13 | from slugify import slugify |
11 | from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion | 14 | from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion |
12 | from schedulers.scheduling_euler_a import EulerAScheduler | 15 | from schedulers.scheduling_euler_a import EulerAScheduler |
13 | 16 | ||
14 | 17 | ||
15 | def parse_args(): | 18 | def create_args_parser(): |
16 | parser = argparse.ArgumentParser( | 19 | parser = argparse.ArgumentParser( |
17 | description="Simple example of a training script." | 20 | description="Simple example of a training script." |
18 | ) | 21 | ) |
@@ -22,6 +25,30 @@ def parse_args(): | |||
22 | default=None, | 25 | default=None, |
23 | ) | 26 | ) |
24 | parser.add_argument( | 27 | parser.add_argument( |
28 | "--scheduler", | ||
29 | type=str, | ||
30 | choices=["plms", "ddim", "klms", "euler_a"], | ||
31 | default="euler_a", | ||
32 | ) | ||
33 | parser.add_argument( | ||
34 | "--output_dir", | ||
35 | type=str, | ||
36 | default="output/inference", | ||
37 | ) | ||
38 | parser.add_argument( | ||
39 | "--config", | ||
40 | type=str, | ||
41 | default=None, | ||
42 | ) | ||
43 | |||
44 | return parser | ||
45 | |||
46 | |||
47 | def create_cmd_parser(): | ||
48 | parser = argparse.ArgumentParser( | ||
49 | description="Simple example of a training script." | ||
50 | ) | ||
51 | parser.add_argument( | ||
25 | "--prompt", | 52 | "--prompt", |
26 | type=str, | 53 | type=str, |
27 | default=None, | 54 | default=None, |
@@ -49,28 +76,17 @@ def parse_args(): | |||
49 | parser.add_argument( | 76 | parser.add_argument( |
50 | "--batch_num", | 77 | "--batch_num", |
51 | type=int, | 78 | type=int, |
52 | default=50, | 79 | default=1, |
53 | ) | 80 | ) |
54 | parser.add_argument( | 81 | parser.add_argument( |
55 | "--steps", | 82 | "--steps", |
56 | type=int, | 83 | type=int, |
57 | default=120, | 84 | default=70, |
58 | ) | ||
59 | parser.add_argument( | ||
60 | "--scheduler", | ||
61 | type=str, | ||
62 | choices=["plms", "ddim", "klms", "euler_a"], | ||
63 | default="euler_a", | ||
64 | ) | 85 | ) |
65 | parser.add_argument( | 86 | parser.add_argument( |
66 | "--guidance_scale", | 87 | "--guidance_scale", |
67 | type=int, | 88 | type=int, |
68 | default=7.5, | 89 | default=7, |
69 | ) | ||
70 | parser.add_argument( | ||
71 | "--clip_guidance_scale", | ||
72 | type=int, | ||
73 | default=100, | ||
74 | ) | 90 | ) |
75 | parser.add_argument( | 91 | parser.add_argument( |
76 | "--seed", | 92 | "--seed", |
@@ -78,21 +94,21 @@ def parse_args(): | |||
78 | default=torch.random.seed(), | 94 | default=torch.random.seed(), |
79 | ) | 95 | ) |
80 | parser.add_argument( | 96 | parser.add_argument( |
81 | "--output_dir", | ||
82 | type=str, | ||
83 | default="output/inference", | ||
84 | ) | ||
85 | parser.add_argument( | ||
86 | "--config", | 97 | "--config", |
87 | type=str, | 98 | type=str, |
88 | default=None, | 99 | default=None, |
89 | ) | 100 | ) |
90 | 101 | ||
91 | args = parser.parse_args() | 102 | return parser |
103 | |||
104 | |||
105 | def run_parser(parser, input=None): | ||
106 | args = parser.parse_known_args(input)[0] | ||
107 | |||
92 | if args.config is not None: | 108 | if args.config is not None: |
93 | with open(args.config, 'rt') as f: | 109 | with open(args.config, 'rt') as f: |
94 | args = parser.parse_args( | 110 | args = parser.parse_known_args( |
95 | namespace=argparse.Namespace(**json.load(f)["args"])) | 111 | namespace=argparse.Namespace(**json.load(f)["args"]))[0] |
96 | 112 | ||
97 | return args | 113 | return args |
98 | 114 | ||
@@ -104,24 +120,24 @@ def save_args(basepath, args, extra={}): | |||
104 | json.dump(info, f, indent=4) | 120 | json.dump(info, f, indent=4) |
105 | 121 | ||
106 | 122 | ||
107 | def gen(args, output_dir): | 123 | def create_pipeline(model, scheduler, dtype=torch.bfloat16): |
108 | tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) | 124 | print("Loading Stable Diffusion pipeline...") |
109 | text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) | ||
110 | clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) | ||
111 | vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16) | ||
112 | unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16) | ||
113 | feature_extractor = CLIPFeatureExtractor.from_pretrained( | ||
114 | "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) | ||
115 | 125 | ||
116 | if args.scheduler == "plms": | 126 | tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype) |
127 | text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype) | ||
128 | vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype) | ||
129 | unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype) | ||
130 | feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype) | ||
131 | |||
132 | if scheduler == "plms": | ||
117 | scheduler = PNDMScheduler( | 133 | scheduler = PNDMScheduler( |
118 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | 134 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True |
119 | ) | 135 | ) |
120 | elif args.scheduler == "klms": | 136 | elif scheduler == "klms": |
121 | scheduler = LMSDiscreteScheduler( | 137 | scheduler = LMSDiscreteScheduler( |
122 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | 138 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
123 | ) | 139 | ) |
124 | elif args.scheduler == "ddim": | 140 | elif scheduler == "ddim": |
125 | scheduler = DDIMScheduler( | 141 | scheduler = DDIMScheduler( |
126 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False | 142 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False |
127 | ) | 143 | ) |
@@ -135,13 +151,24 @@ def gen(args, output_dir): | |||
135 | vae=vae, | 151 | vae=vae, |
136 | unet=unet, | 152 | unet=unet, |
137 | tokenizer=tokenizer, | 153 | tokenizer=tokenizer, |
138 | clip_model=clip_model, | ||
139 | scheduler=scheduler, | 154 | scheduler=scheduler, |
140 | feature_extractor=feature_extractor | 155 | feature_extractor=feature_extractor |
141 | ) | 156 | ) |
142 | pipeline.enable_attention_slicing() | 157 | pipeline.enable_attention_slicing() |
143 | pipeline.to("cuda") | 158 | pipeline.to("cuda") |
144 | 159 | ||
160 | print("Pipeline loaded.") | ||
161 | |||
162 | return pipeline | ||
163 | |||
164 | |||
165 | def generate(output_dir, pipeline, args): | ||
166 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | ||
167 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") | ||
168 | output_dir.mkdir(parents=True, exist_ok=True) | ||
169 | |||
170 | save_args(output_dir, args) | ||
171 | |||
145 | with autocast("cuda"): | 172 | with autocast("cuda"): |
146 | for i in range(args.batch_num): | 173 | for i in range(args.batch_num): |
147 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) | 174 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) |
@@ -152,7 +179,6 @@ def gen(args, output_dir): | |||
152 | negative_prompt=args.negative_prompt, | 179 | negative_prompt=args.negative_prompt, |
153 | num_inference_steps=args.steps, | 180 | num_inference_steps=args.steps, |
154 | guidance_scale=args.guidance_scale, | 181 | guidance_scale=args.guidance_scale, |
155 | clip_guidance_scale=args.clip_guidance_scale, | ||
156 | generator=generator, | 182 | generator=generator, |
157 | ).images | 183 | ).images |
158 | 184 | ||
@@ -160,18 +186,56 @@ def gen(args, output_dir): | |||
160 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) | 186 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) |
161 | 187 | ||
162 | 188 | ||
163 | def main(): | 189 | class CmdParse(cmd.Cmd): |
164 | args = parse_args() | 190 | prompt = 'dream> ' |
191 | commands = [] | ||
165 | 192 | ||
166 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 193 | def __init__(self, output_dir, pipeline, parser): |
167 | output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") | 194 | super().__init__() |
168 | output_dir.mkdir(parents=True, exist_ok=True) | ||
169 | 195 | ||
170 | save_args(output_dir, args) | 196 | self.output_dir = output_dir |
197 | self.pipeline = pipeline | ||
198 | self.parser = parser | ||
199 | |||
200 | def default(self, line): | ||
201 | line = line.replace("'", "\\'") | ||
202 | |||
203 | try: | ||
204 | elements = shlex.split(line) | ||
205 | except ValueError as e: | ||
206 | print(str(e)) | ||
207 | |||
208 | if elements[0] == 'q': | ||
209 | return True | ||
210 | |||
211 | try: | ||
212 | args = run_parser(self.parser, elements) | ||
213 | except SystemExit: | ||
214 | self.parser.print_help() | ||
215 | |||
216 | if len(args.prompt) == 0: | ||
217 | print('Try again with a prompt!') | ||
218 | |||
219 | try: | ||
220 | generate(self.output_dir, self.pipeline, args) | ||
221 | except KeyboardInterrupt: | ||
222 | print('Generation cancelled.') | ||
223 | |||
224 | def do_exit(self, line): | ||
225 | return True | ||
226 | |||
227 | |||
228 | def main(): | ||
229 | logging.basicConfig(stream=sys.stdout, level=logging.WARN) | ||
171 | 230 | ||
172 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 231 | args_parser = create_args_parser() |
232 | args = run_parser(args_parser) | ||
233 | output_dir = Path(args.output_dir) | ||
173 | 234 | ||
174 | gen(args, output_dir) | 235 | pipeline = create_pipeline(args.model, args.scheduler) |
236 | cmd_parser = create_cmd_parser() | ||
237 | cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) | ||
238 | cmd_prompt.cmdloop() | ||
175 | 239 | ||
176 | 240 | ||
177 | if __name__ == "__main__": | 241 | if __name__ == "__main__": |