-rw-r--r--  dreambooth.py    | 54
-rw-r--r--  environment.yaml |  4
-rw-r--r--  infer.py         | 52
3 files changed, 63 insertions, 47 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 79b3d2c..2b8a35e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -859,7 +859,14 @@ def main():
     # Only show the progress bar once on each machine.
 
     global_step = 0
-    min_val_loss = np.inf
+
+    total_loss = 0.0
+    total_acc = 0.0
+
+    total_loss_val = 0.0
+    total_acc_val = 0.0
+
+    max_acc_val = 0.0
 
     checkpointer = Checkpointer(
         datamodule=datamodule,
@@ -905,7 +912,6 @@ def main():
 
         unet.train()
         text_encoder.train()
-        train_loss = 0.0
 
         sample_checkpoint = False
 
@@ -978,8 +984,11 @@
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-            loss = loss.detach().item()
-            train_loss += loss
+            acc = (noise_pred == latents).float()
+            acc = acc.mean()
+
+            total_loss += loss.item()
+            total_acc += acc.item()
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -996,7 +1005,10 @@ def main():
                     sample_checkpoint = True
 
                 logs = {
-                    "train/loss": loss,
+                    "train/loss": total_loss / global_step,
+                    "train/acc": total_acc / global_step,
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
                     "lr/unet": lr_scheduler.get_last_lr()[0],
                     "lr/text": lr_scheduler.get_last_lr()[1]
                 }
@@ -1010,13 +1022,10 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-        train_loss /= len(train_dataloader)
-
         accelerator.wait_for_everyone()
 
         unet.eval()
         text_encoder.eval()
-        val_loss = 0.0
 
         with torch.autocast("cuda"), torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
@@ -1039,28 +1048,41 @@ def main():
 
                 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
-                loss = loss.detach().item()
-                val_loss += loss
+                acc = (noise_pred == latents).float()
+                acc = acc.mean()
+
+                total_loss_val += loss.item()
+                total_acc_val += acc.item()
 
                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
-                logs = {"val/loss": loss}
+                logs = {
+                    "val/loss": total_loss_val / global_step,
+                    "val/acc": total_acc_val / global_step,
+                    "val/cur_loss": loss.item(),
+                    "val/cur_acc": acc.item(),
+                }
                 local_progress_bar.set_postfix(**logs)
 
-        val_loss /= len(val_dataloader)
+        val_step = (epoch + 1) * len(val_dataloader)
+        avg_acc_val = total_acc_val / val_step
+        avg_loss_val = total_loss_val / val_step
 
-        accelerator.log({"val/loss": val_loss}, step=global_step)
+        accelerator.log({
+            "val/loss": avg_loss_val,
+            "val/acc": avg_acc_val,
+        }, step=global_step)
 
         local_progress_bar.clear()
         global_progress_bar.clear()
 
-        if min_val_loss > val_loss:
+        if avg_acc_val > max_acc_val:
             accelerator.print(
-                f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val:.2e}")
             checkpointer.save_embedding(global_step, "milestone")
-            min_val_loss = val_loss
+            max_acc_val = avg_acc_val
 
         if sample_checkpoint and accelerator.is_main_process:
             checkpointer.save_samples(
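Note on the metric change above: the per-epoch train_loss/val_loss accumulators are replaced by running totals divided by global_step, so the logged "train/loss" and "val/loss" become means over the whole run, while the new "cur_" keys keep the per-batch snapshot. A minimal, self-contained sketch of the same pattern (random tensors stand in for the model's noise predictions and targets; names mirror the diff):

import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the training loop's (prediction, target) pairs.
batches = [(torch.randn(4, 8), torch.randn(4, 8)) for _ in range(10)]

total_loss = 0.0
global_step = 0

for noise_pred, noise in batches:
    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

    global_step += 1
    total_loss += loss.item()

    logs = {
        "train/loss": total_loss / global_step,  # running mean over all steps so far
        "train/cur_loss": loss.item(),           # current step only
    }

print(logs)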
diff --git a/environment.yaml b/environment.yaml
index 7aa5312..4972ebd 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,7 +11,7 @@ dependencies:
   - pytorch=1.12.1
   - torchvision=0.13.1
   - pandas=1.4.3
-  - xformers=0.0.14.dev315
+  - xformers=0.0.15.dev337
   - pip:
     - -e .
     - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
@@ -35,4 +35,4 @@ dependencies:
     - torch-fidelity==0.3.0
     - torchmetrics==0.9.3
     - transformers==4.23.1
-    - triton==2.0.0.dev20220924
+    - triton==2.0.0.dev20221105
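Both pins move to newer nightly builds of xformers and triton. Assuming a conda-managed setup, the updated environment can usually be applied in place with: conda env update -f environment.yaml --prune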
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -20,7 +20,6 @@ torch.backends.cuda.matmul.allow_tf32 = True
 
 default_args = {
     "model": None,
-    "scheduler": "dpmpp",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
     "output_dir": "output/inference",
@@ -29,6 +28,7 @@ default_args = {
 
 
 default_cmds = {
+    "scheduler": "dpmpp",
     "prompt": None,
     "negative_prompt": None,
     "image": None,
@@ -62,11 +62,6 @@ def create_args_parser():
         type=str,
     )
     parser.add_argument(
-        "--scheduler",
-        type=str,
-        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
-    )
-    parser.add_argument(
         "--precision",
         type=str,
         choices=["fp32", "fp16", "bf16"],
@@ -92,6 +87,11 @@ def create_cmd_parser():
         description="Simple example of a training script."
     )
     parser.add_argument(
+        "--scheduler",
+        type=str,
+        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
+    )
+    parser.add_argument(
         "--prompt",
         type=str,
         nargs="+",
@@ -199,37 +199,17 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
     print(f"Loaded {placeholder_token}")
 
 
-def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
+def create_pipeline(model, ti_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
     text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
     vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
+    scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
     load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
 
-    if scheduler == "plms":
-        scheduler = PNDMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-        )
-    elif scheduler == "klms":
-        scheduler = LMSDiscreteScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
-    elif scheduler == "ddim":
-        scheduler = DDIMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
-        )
-    elif scheduler == "dpmpp":
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
-    else:
-        scheduler = EulerAncestralDiscreteScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
-
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
@@ -264,6 +244,17 @@ def generate(output_dir, pipeline, args):
     else:
         init_image = None
 
+    if args.scheduler == "plms":
+        pipeline.scheduler = PNDMScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "klms":
+        pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "ddim":
+        pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "dpmpp":
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "euler_a":
+        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+
     with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(
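The hard-coded beta schedules are gone from create_pipeline; generate() now swaps schedulers with from_config, which copies the configuration of the scheduler the pipeline currently holds (the DDIMScheduler loaded from the checkpoint's scheduler subfolder above). A minimal sketch of the mechanism, assuming the diffusers package from environment.yaml:

from diffusers import DDIMScheduler, DPMSolverMultistepScheduler

# Stand-in for the scheduler a pipeline was constructed with.
ddim = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

# from_config builds a different scheduler class from the existing one's settings,
# so beta_start/beta_end/beta_schedule no longer need to be repeated per class.
dpmpp = DPMSolverMultistepScheduler.from_config(ddim.config)
print(dpmpp.config.beta_schedule)  # "scaled_linear", inherited from the DDIM config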
@@ -331,6 +322,9 @@ class CmdParse(cmd.Cmd):
             generate(self.output_dir, self.pipeline, args)
         except KeyboardInterrupt:
             print('Generation cancelled.')
+        except Exception as e:
+            print(e)
+            return
 
     def do_exit(self, line):
         return True
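The added except clause keeps the interactive loop alive when a command fails to parse or generate; previously only KeyboardInterrupt was handled and any other error crashed the REPL. The same pattern in isolation, as a runnable cmd.Cmd sketch (class and behavior are illustrative, not from this repo):

import cmd

class Repl(cmd.Cmd):
    prompt = ">> "

    def default(self, line):
        try:
            # Stand-in for parsing the line and running generation.
            print(int(line) * 2)
        except KeyboardInterrupt:
            print("Cancelled.")
        except Exception as e:
            # Report and return None so cmdloop() continues; returning True would exit.
            print(e)
            return

if __name__ == "__main__":
    Repl().cmdloop()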
@@ -345,7 +339,7 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, args.ti_embeddings_dir, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
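Net effect: --scheduler moves from the startup parser to the per-command parser, so the sampler can change between generations without rebuilding the pipeline. A hypothetical session (launch flags from create_args_parser elided; the prompt string and ">>" prompt are illustrative):

$ python infer.py
>> --prompt "a photo of a corgi" --scheduler dpmpp
>> --prompt "a photo of a corgi" --scheduler euler_a
>> exit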