Diffstat (limited to 'main.py')
 main.py | 37 ++++++++++++++++++++++++++-----------
 1 file changed, 26 insertions(+), 11 deletions(-)
diff --git a/main.py b/main.py
index aa5af72..8be79e5 100644
--- a/main.py
+++ b/main.py
@@ -106,6 +106,11 @@ def parse_args():
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
         "--learning_rate",
         type=float,
         default=1e-4,
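
Note: the new --gradient_checkpointing option is a plain store_true switch, so it defaults to False and is only acted on later in main() (see the hunk around new line 499). A minimal sketch of the flag in isolation, outside the project's full parser:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )

    print(parser.parse_args([]).gradient_checkpointing)                             # False
    print(parser.parse_args(["--gradient_checkpointing"]).gradient_checkpointing)   # True
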
@@ -213,7 +218,7 @@ def parse_args():
     elif args.config is not None:
         with open(args.config, 'rt') as f:
             args = parser.parse_args(
-                namespace=argparse.Namespace(**json.load(f)))
+                namespace=argparse.Namespace(**json.load(f)["args"]))
 
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
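
Note: with this change, a JSON file passed via --config is expected to nest the argument values under a top-level "args" key rather than at the document root. A hypothetical config illustrating the new layout (the argument names are only examples):

    {
        "args": {
            "learning_rate": 1e-4,
            "gradient_checkpointing": true
        }
    }
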
@@ -289,7 +294,7 @@ class Checkpointer:
         self.stable_sample_batches = stable_sample_batches
 
     @torch.no_grad()
-    def checkpoint(self, step, text_encoder, save_samples=True, path=None):
+    def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
         print("Saving checkpoint for step %d..." % step)
         with self.accelerator.autocast():
             if path is None:
@@ -302,12 +307,11 @@ class Checkpointer:
             learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
             learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
 
-            filename = f"%s_%d.bin" % (slugify(self.placeholder_token), step)
+            filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
             if path is not None:
                 torch.save(learned_embeds_dict, path)
             else:
-                torch.save(learned_embeds_dict,
-                           f"{checkpoints_path}/{filename}")
+                torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
                 torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
             del unwrapped
             del learned_embeds
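
Note: the added postfix keeps checkpoints from different phases apart; the call sites further down pass "training", "milestone", and "end". A worked example of the new naming, assuming the placeholder token slugifies to "my-token" (hypothetical) and step 500:

    "%s_%d_%s.bin" % ("my-token", 500, "training")   # -> "my-token_500_training.bin"

The save to last.bin is untouched, so the resume_checkpoint path recorded by save_resume_file below still points at a file that exists.
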
@@ -338,7 +342,7 @@ class Checkpointer:
338 "validation": self.datamodule.val_dataloader(), 342 "validation": self.datamodule.val_dataloader(),
339 }[mode] 343 }[mode]
340 344
341 if mode == "validation" and self.stable_sample_batches > 0: 345 if mode == "validation" and self.stable_sample_batches > 0 and step > 0:
342 stable_latents = torch.randn( 346 stable_latents = torch.randn(
343 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), 347 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
344 device=pipeline.device, 348 device=pipeline.device,
@@ -348,15 +352,18 @@ class Checkpointer:
             all_samples = []
             filename = f"stable_step_%d.png" % (step)
 
+            data_enum = enumerate(data)
+
             # Generate and save stable samples
             for i in range(0, self.stable_sample_batches):
-                prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
+                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
+                    batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]
 
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
                         height=self.sample_image_size,
-                        latents=stable_latents,
+                        latents=stable_latents[:len(prompt)],
                         width=self.sample_image_size,
                         guidance_scale=guidance_scale,
                         eta=eta,
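
Note: the rewritten list comprehension gathers individual prompts across batches, capped at sample_batch_size, instead of taking one whole batch entry per index. A rough expanded equivalent, assuming data is a DataLoader whose batches carry a "prompt" list and whose batch_size attribute is set:

    prompt = []
    for i, batch in data_enum:            # data_enum is created once, before the sampling loop
        for j, p in enumerate(batch["prompt"]):
            if i * data.batch_size + j < self.sample_batch_size:
                prompt.append(p)

Slicing stable_latents[:len(prompt)] then keeps the number of latents in step with however many prompts were actually collected.
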
@@ -377,9 +384,12 @@ class Checkpointer:
             all_samples = []
             filename = f"step_%d.png" % (step)
 
+            data_enum = enumerate(data)
+
             # Generate and save random samples
             for i in range(0, self.random_sample_batches):
-                prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
+                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
+                    batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]
 
                 with self.accelerator.autocast():
                     samples = pipeline(
@@ -486,6 +496,9 @@ def main():
         args.pretrained_model_name_or_path + '/unet',
     )
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
     slice_size = unet.config.attention_head_dim // 2
     unet.set_attention_slice(slice_size)
 
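
Note: enable_gradient_checkpointing() trades a slower backward pass for lower activation memory and is only switched on when the new flag is set; attention slicing stays enabled unconditionally as before. A minimal standalone sketch of the same two memory savers with diffusers' UNet2DConditionModel (the model id is only illustrative; the script loads from args.pretrained_model_name_or_path + '/unet'):

    from diffusers import UNet2DConditionModel

    gradient_checkpointing = True  # stands in for args.gradient_checkpointing

    unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()  # recompute activations during backward instead of storing them

    unet.set_attention_slice(unet.config.attention_head_dim // 2)  # compute attention in smaller slices
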
@@ -693,7 +706,7 @@ def main():
                     progress_bar.clear()
                     local_progress_bar.clear()
 
-                    checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
+                    checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
                     save_resume_file(basepath, args, {
                         "global_step": global_step + global_step_offset,
                         "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
@@ -753,6 +766,7 @@ def main():
 
                 if min_val_loss > val_loss:
                     accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
                     min_val_loss = val_loss
 
                 checkpointer.save_samples(
@@ -768,6 +782,7 @@ def main():
             print("Finished! Saving final checkpoint and resume state.")
             checkpointer.checkpoint(
                 global_step + global_step_offset,
+                "end",
                 text_encoder,
                 path=f"{basepath}/learned_embeds.bin"
             )
@@ -782,7 +797,7 @@ def main():
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
+            checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder)
             save_resume_file(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"