-rw-r--r--   dreambooth.py                                          11
-rw-r--r--   infer.py                                                3
-rw-r--r--   pipelines/stable_diffusion/vlpn_stable_diffusion.py    27
-rw-r--r--   textual_inversion.py                                    1

4 files changed, 23 insertions, 19 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 3dd0920..31dbea2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -32,6 +32,7 @@ logger = get_logger(__name__)
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.benchmark = True
 
 
 def parse_args():
@@ -474,7 +475,6 @@ class Checkpointer:
             scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
-        pipeline.enable_vae_slicing()
 
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
@@ -550,6 +550,12 @@ class Checkpointer:
 def main():
     args = parse_args()
 
+    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
+
     instance_identifier = args.instance_identifier
 
     if len(args.placeholder_token) != 0:
@@ -587,6 +593,7 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    vae.enable_slicing()
     unet.set_use_memory_efficient_attention_xformers(True)
 
     if args.gradient_checkpointing:
@@ -903,7 +910,7 @@ def main():
         sample_checkpoint = False
 
         for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(itertools.chain(unet, text_encoder)):
+            with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
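The new guard and the switch from accumulate(itertools.chain(unet, text_encoder)) to accumulate(unet) go together: Accelerate's accumulate context suppresses inter-process gradient synchronization on all but the final accumulation step, and it only manages this for the single module it is handed, so gradient accumulation while the text encoder also trains across multiple processes is rejected up front (the same check used in the upstream diffusers DreamBooth example). A rough, self-contained sketch of the pattern; the model, optimizer, and step count below are placeholders, not the script's real objects:

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

for step in range(16):
    batch = torch.randn(2, 8)
    # Gradients are synchronized across processes only on the step where the
    # optimizer actually runs; in between, backward() just accumulates locally
    # on the one module passed to accumulate().
    with accelerator.accumulate(model):
        loss = model(batch).pow(2).mean()
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()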
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -16,6 +16,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.benchmark = True
 
 
 default_args = {
@@ -37,7 +38,7 @@ default_cmds = {
     "height": 512,
     "batch_size": 1,
     "batch_num": 1,
-    "steps": 50,
+    "steps": 30,
     "guidance_scale": 7.0,
     "seed": None,
     "config": None,
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index c77c4d1..9b51763 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,7 +20,6 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.models.vae import DecoderOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
@@ -70,7 +69,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler._internal_dict = FrozenDict(new_config)
 
         self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
-        self.use_slicing = False
 
         self.register_modules(
             vae=vae,
@@ -108,9 +106,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
+
         self.unet.set_attention_slice(slice_size)
 
     def disable_attention_slicing(self):
@@ -144,14 +147,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
         When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
         steps. This is useful to save some memory and allow larger batch sizes.
         """
-        self.use_slicing = True
+        self.vae.enable_slicing()
 
     def disable_vae_slicing(self):
         r"""
         Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
         computing decoding in one step.
         """
-        self.use_slicing = False
+        self.vae.disable_slicing()
 
     @property
     def execution_device(self):
@@ -297,20 +300,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def decode_latents(self, latents):
         latents = 1 / 0.18215 * latents
-        image = self.vae_decode(latents).sample
+        image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
-    def vae_decode(self, latents):
-        if self.use_slicing:
-            decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)]
-            decoded = torch.cat(decoded_slices)
-            return DecoderOutput(sample=decoded)
-        else:
-            return self.vae.decode(latents)
-
     @torch.no_grad()
     def __call__(
         self,
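The pipeline now drops its hand-rolled use_slicing flag and vae_decode helper and delegates to AutoencoderKL's built-in enable_slicing()/disable_slicing(), which perform the same one-sample-at-a-time decoding internally; the "auto" attention-slicing path also handles UNets whose attention_head_dim is a list by picking the smallest head size. A minimal sketch of the VAE slicing behavior, with an illustrative checkpoint name and dummy latents:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # example VAE
vae.enable_slicing()

latents = torch.randn(4, 4, 64, 64)  # dummy batch of SD latents
with torch.no_grad():
    # With slicing enabled, decode() processes one latent at a time and
    # concatenates the results, roughly what the removed vae_decode helper did:
    #   torch.cat([vae.decode(l).sample for l in latents.split(1)])
    images = vae.decode(1 / 0.18215 * latents).sample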
diff --git a/textual_inversion.py b/textual_inversion.py
index 7ac9638..d6be522 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -31,6 +31,7 @@ logger = get_logger(__name__)
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.benchmark = True
 
 
 def parse_args():
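All three entry points (dreambooth.py, textual_inversion.py, infer.py) now set the same two global PyTorch performance switches near the top of the file. A short annotated sketch of what they do:

import torch

# Allow TensorFloat-32 for fp32 matmuls on Ampere and newer GPUs: a large
# speedup at slightly reduced precision.
torch.backends.cuda.matmul.allow_tf32 = True

# Let cuDNN benchmark convolution algorithms and cache the fastest one per
# input shape. Helps when shapes stay constant (fixed training resolution),
# but can slow things down if shapes change every iteration.
torch.backends.cudnn.benchmark = True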
