diff options
-rw-r--r-- | dreambooth.py | 11 | ||||
-rw-r--r-- | infer.py | 3 | ||||
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 27 | ||||
-rw-r--r-- | textual_inversion.py | 1 |
4 files changed, 23 insertions, 19 deletions
diff --git a/dreambooth.py b/dreambooth.py index 3dd0920..31dbea2 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -32,6 +32,7 @@ logger = get_logger(__name__) | |||
32 | 32 | ||
33 | 33 | ||
34 | torch.backends.cuda.matmul.allow_tf32 = True | 34 | torch.backends.cuda.matmul.allow_tf32 = True |
35 | torch.backends.cudnn.benchmark = True | ||
35 | 36 | ||
36 | 37 | ||
37 | def parse_args(): | 38 | def parse_args(): |
@@ -474,7 +475,6 @@ class Checkpointer: | |||
474 | scheduler=self.scheduler, | 475 | scheduler=self.scheduler, |
475 | ).to(self.accelerator.device) | 476 | ).to(self.accelerator.device) |
476 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 477 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
477 | pipeline.enable_vae_slicing() | ||
478 | 478 | ||
479 | train_data = self.datamodule.train_dataloader() | 479 | train_data = self.datamodule.train_dataloader() |
480 | val_data = self.datamodule.val_dataloader() | 480 | val_data = self.datamodule.val_dataloader() |
@@ -550,6 +550,12 @@ class Checkpointer: | |||
550 | def main(): | 550 | def main(): |
551 | args = parse_args() | 551 | args = parse_args() |
552 | 552 | ||
553 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: | ||
554 | raise ValueError( | ||
555 | "Gradient accumulation is not supported when training the text encoder in distributed training. " | ||
556 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." | ||
557 | ) | ||
558 | |||
553 | instance_identifier = args.instance_identifier | 559 | instance_identifier = args.instance_identifier |
554 | 560 | ||
555 | if len(args.placeholder_token) != 0: | 561 | if len(args.placeholder_token) != 0: |
@@ -587,6 +593,7 @@ def main(): | |||
587 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | 593 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( |
588 | args.pretrained_model_name_or_path, subfolder='scheduler') | 594 | args.pretrained_model_name_or_path, subfolder='scheduler') |
589 | 595 | ||
596 | vae.enable_slicing() | ||
590 | unet.set_use_memory_efficient_attention_xformers(True) | 597 | unet.set_use_memory_efficient_attention_xformers(True) |
591 | 598 | ||
592 | if args.gradient_checkpointing: | 599 | if args.gradient_checkpointing: |
@@ -903,7 +910,7 @@ def main(): | |||
903 | sample_checkpoint = False | 910 | sample_checkpoint = False |
904 | 911 | ||
905 | for step, batch in enumerate(train_dataloader): | 912 | for step, batch in enumerate(train_dataloader): |
906 | with accelerator.accumulate(itertools.chain(unet, text_encoder)): | 913 | with accelerator.accumulate(unet): |
907 | # Convert images to latent space | 914 | # Convert images to latent space |
908 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 915 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
909 | latents = latents * 0.18215 | 916 | latents = latents * 0.18215 |
diff --git a/infer.py b/infer.py --- a/infer.py +++ b/infer.py | |||
@@ -16,6 +16,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
16 | 16 | ||
17 | 17 | ||
18 | torch.backends.cuda.matmul.allow_tf32 = True | 18 | torch.backends.cuda.matmul.allow_tf32 = True |
19 | torch.backends.cudnn.benchmark = True | ||
19 | 20 | ||
20 | 21 | ||
21 | default_args = { | 22 | default_args = { |
@@ -37,7 +38,7 @@ default_cmds = { | |||
37 | "height": 512, | 38 | "height": 512, |
38 | "batch_size": 1, | 39 | "batch_size": 1, |
39 | "batch_num": 1, | 40 | "batch_num": 1, |
40 | "steps": 50, | 41 | "steps": 30, |
41 | "guidance_scale": 7.0, | 42 | "guidance_scale": 7.0, |
42 | "seed": None, | 43 | "seed": None, |
43 | "config": None, | 44 | "config": None, |
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index c77c4d1..9b51763 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -20,7 +20,6 @@ from diffusers import ( | |||
20 | PNDMScheduler, | 20 | PNDMScheduler, |
21 | ) | 21 | ) |
22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
23 | from diffusers.models.vae import DecoderOutput | ||
24 | from diffusers.utils import logging | 23 | from diffusers.utils import logging |
25 | from transformers import CLIPTextModel, CLIPTokenizer | 24 | from transformers import CLIPTextModel, CLIPTokenizer |
26 | from models.clip.prompt import PromptProcessor | 25 | from models.clip.prompt import PromptProcessor |
@@ -70,7 +69,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
70 | scheduler._internal_dict = FrozenDict(new_config) | 69 | scheduler._internal_dict = FrozenDict(new_config) |
71 | 70 | ||
72 | self.prompt_processor = PromptProcessor(tokenizer, text_encoder) | 71 | self.prompt_processor = PromptProcessor(tokenizer, text_encoder) |
73 | self.use_slicing = False | ||
74 | 72 | ||
75 | self.register_modules( | 73 | self.register_modules( |
76 | vae=vae, | 74 | vae=vae, |
@@ -108,9 +106,14 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
108 | `attention_head_dim` must be a multiple of `slice_size`. | 106 | `attention_head_dim` must be a multiple of `slice_size`. |
109 | """ | 107 | """ |
110 | if slice_size == "auto": | 108 | if slice_size == "auto": |
111 | # half the attention head size is usually a good trade-off between | 109 | if isinstance(self.unet.config.attention_head_dim, int): |
112 | # speed and memory | 110 | # half the attention head size is usually a good trade-off between |
113 | slice_size = self.unet.config.attention_head_dim // 2 | 111 | # speed and memory |
112 | slice_size = self.unet.config.attention_head_dim // 2 | ||
113 | else: | ||
114 | # if `attention_head_dim` is a list, take the smallest head size | ||
115 | slice_size = min(self.unet.config.attention_head_dim) | ||
116 | |||
114 | self.unet.set_attention_slice(slice_size) | 117 | self.unet.set_attention_slice(slice_size) |
115 | 118 | ||
116 | def disable_attention_slicing(self): | 119 | def disable_attention_slicing(self): |
@@ -144,14 +147,14 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
144 | When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several | 147 | When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several |
145 | steps. This is useful to save some memory and allow larger batch sizes. | 148 | steps. This is useful to save some memory and allow larger batch sizes. |
146 | """ | 149 | """ |
147 | self.use_slicing = True | 150 | self.vae.enable_slicing() |
148 | 151 | ||
149 | def disable_vae_slicing(self): | 152 | def disable_vae_slicing(self): |
150 | r""" | 153 | r""" |
151 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to | 154 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to |
152 | computing decoding in one step. | 155 | computing decoding in one step. |
153 | """ | 156 | """ |
154 | self.use_slicing = False | 157 | self.vae.disable_slicing() |
155 | 158 | ||
156 | @property | 159 | @property |
157 | def execution_device(self): | 160 | def execution_device(self): |
@@ -297,20 +300,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
297 | 300 | ||
298 | def decode_latents(self, latents): | 301 | def decode_latents(self, latents): |
299 | latents = 1 / 0.18215 * latents | 302 | latents = 1 / 0.18215 * latents |
300 | image = self.vae_decode(latents).sample | 303 | image = self.vae.decode(latents).sample |
301 | image = (image / 2 + 0.5).clamp(0, 1) | 304 | image = (image / 2 + 0.5).clamp(0, 1) |
302 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 | 305 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 |
303 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() | 306 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() |
304 | return image | 307 | return image |
305 | 308 | ||
306 | def vae_decode(self, latents): | ||
307 | if self.use_slicing: | ||
308 | decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)] | ||
309 | decoded = torch.cat(decoded_slices) | ||
310 | return DecoderOutput(sample=decoded) | ||
311 | else: | ||
312 | return self.vae.decode(latents) | ||
313 | |||
314 | @torch.no_grad() | 309 | @torch.no_grad() |
315 | def __call__( | 310 | def __call__( |
316 | self, | 311 | self, |
diff --git a/textual_inversion.py b/textual_inversion.py index 7ac9638..d6be522 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -31,6 +31,7 @@ logger = get_logger(__name__) | |||
31 | 31 | ||
32 | 32 | ||
33 | torch.backends.cuda.matmul.allow_tf32 = True | 33 | torch.backends.cuda.matmul.allow_tf32 = True |
34 | torch.backends.cudnn.benchmark = True | ||
34 | 35 | ||
35 | 36 | ||
36 | def parse_args(): | 37 | def parse_args(): |