summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 3dd0920..31dbea2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -32,6 +32,7 @@ logger = get_logger(__name__)
32 32
33 33
34torch.backends.cuda.matmul.allow_tf32 = True 34torch.backends.cuda.matmul.allow_tf32 = True
35torch.backends.cudnn.benchmark = True
35 36
36 37
37def parse_args(): 38def parse_args():
@@ -474,7 +475,6 @@ class Checkpointer:
474 scheduler=self.scheduler, 475 scheduler=self.scheduler,
475 ).to(self.accelerator.device) 476 ).to(self.accelerator.device)
476 pipeline.set_progress_bar_config(dynamic_ncols=True) 477 pipeline.set_progress_bar_config(dynamic_ncols=True)
477 pipeline.enable_vae_slicing()
478 478
479 train_data = self.datamodule.train_dataloader() 479 train_data = self.datamodule.train_dataloader()
480 val_data = self.datamodule.val_dataloader() 480 val_data = self.datamodule.val_dataloader()
@@ -550,6 +550,12 @@ class Checkpointer:
550def main(): 550def main():
551 args = parse_args() 551 args = parse_args()
552 552
553 if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
554 raise ValueError(
555 "Gradient accumulation is not supported when training the text encoder in distributed training. "
556 "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
557 )
558
553 instance_identifier = args.instance_identifier 559 instance_identifier = args.instance_identifier
554 560
555 if len(args.placeholder_token) != 0: 561 if len(args.placeholder_token) != 0:
@@ -587,6 +593,7 @@ def main():
587 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( 593 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
588 args.pretrained_model_name_or_path, subfolder='scheduler') 594 args.pretrained_model_name_or_path, subfolder='scheduler')
589 595
596 vae.enable_slicing()
590 unet.set_use_memory_efficient_attention_xformers(True) 597 unet.set_use_memory_efficient_attention_xformers(True)
591 598
592 if args.gradient_checkpointing: 599 if args.gradient_checkpointing:
@@ -903,7 +910,7 @@ def main():
903 sample_checkpoint = False 910 sample_checkpoint = False
904 911
905 for step, batch in enumerate(train_dataloader): 912 for step, batch in enumerate(train_dataloader):
906 with accelerator.accumulate(itertools.chain(unet, text_encoder)): 913 with accelerator.accumulate(unet):
907 # Convert images to latent space 914 # Convert images to latent space
908 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 915 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
909 latents = latents * 0.18215 916 latents = latents * 0.18215