diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 11 |
1 file changed, 9 insertions, 2 deletions
diff --git a/dreambooth.py b/dreambooth.py index 3dd0920..31dbea2 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -32,6 +32,7 @@ logger = get_logger(__name__) | |||
32 | 32 | ||
33 | 33 | ||
34 | torch.backends.cuda.matmul.allow_tf32 = True | 34 | torch.backends.cuda.matmul.allow_tf32 = True |
35 | torch.backends.cudnn.benchmark = True | ||
35 | 36 | ||
36 | 37 | ||
37 | def parse_args(): | 38 | def parse_args(): |
@@ -474,7 +475,6 @@ class Checkpointer: | |||
474 | scheduler=self.scheduler, | 475 | scheduler=self.scheduler, |
475 | ).to(self.accelerator.device) | 476 | ).to(self.accelerator.device) |
476 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 477 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
477 | pipeline.enable_vae_slicing() | ||
478 | 478 | ||
479 | train_data = self.datamodule.train_dataloader() | 479 | train_data = self.datamodule.train_dataloader() |
480 | val_data = self.datamodule.val_dataloader() | 480 | val_data = self.datamodule.val_dataloader() |
@@ -550,6 +550,12 @@ class Checkpointer: | |||
550 | def main(): | 550 | def main(): |
551 | args = parse_args() | 551 | args = parse_args() |
552 | 552 | ||
553 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: | ||
554 | raise ValueError( | ||
555 | "Gradient accumulation is not supported when training the text encoder in distributed training. " | ||
556 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." | ||
557 | ) | ||
558 | |||
553 | instance_identifier = args.instance_identifier | 559 | instance_identifier = args.instance_identifier |
554 | 560 | ||
555 | if len(args.placeholder_token) != 0: | 561 | if len(args.placeholder_token) != 0: |
@@ -587,6 +593,7 @@ def main(): | |||
587 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | 593 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( |
588 | args.pretrained_model_name_or_path, subfolder='scheduler') | 594 | args.pretrained_model_name_or_path, subfolder='scheduler') |
589 | 595 | ||
596 | vae.enable_slicing() | ||
590 | unet.set_use_memory_efficient_attention_xformers(True) | 597 | unet.set_use_memory_efficient_attention_xformers(True) |
591 | 598 | ||
592 | if args.gradient_checkpointing: | 599 | if args.gradient_checkpointing: |
@@ -903,7 +910,7 @@ def main(): | |||
903 | sample_checkpoint = False | 910 | sample_checkpoint = False |
904 | 911 | ||
905 | for step, batch in enumerate(train_dataloader): | 912 | for step, batch in enumerate(train_dataloader): |
906 | with accelerator.accumulate(itertools.chain(unet, text_encoder)): | 913 | with accelerator.accumulate(unet): |
907 | # Convert images to latent space | 914 | # Convert images to latent space |
908 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 915 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
909 | latents = latents * 0.18215 | 916 | latents = latents * 0.18215 |