summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r-- dreambooth.py 11
-rw-r--r-- infer.py 3
-rw-r--r-- pipelines/stable_diffusion/vlpn_stable_diffusion.py 27
-rw-r--r-- textual_inversion.py 1
4 files changed, 23 insertions, 19 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 3dd0920..31dbea2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -32,6 +32,7 @@ logger = get_logger(__name__)
32 32
33 33
34torch.backends.cuda.matmul.allow_tf32 = True 34torch.backends.cuda.matmul.allow_tf32 = True
35torch.backends.cudnn.benchmark = True
35 36
36 37
37def parse_args(): 38def parse_args():
@@ -474,7 +475,6 @@ class Checkpointer:
474 scheduler=self.scheduler, 475 scheduler=self.scheduler,
475 ).to(self.accelerator.device) 476 ).to(self.accelerator.device)
476 pipeline.set_progress_bar_config(dynamic_ncols=True) 477 pipeline.set_progress_bar_config(dynamic_ncols=True)
477 pipeline.enable_vae_slicing()
478 478
479 train_data = self.datamodule.train_dataloader() 479 train_data = self.datamodule.train_dataloader()
480 val_data = self.datamodule.val_dataloader() 480 val_data = self.datamodule.val_dataloader()
@@ -550,6 +550,12 @@ class Checkpointer:
550def main(): 550def main():
551 args = parse_args() 551 args = parse_args()
552 552
553 if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
554 raise ValueError(
555 "Gradient accumulation is not supported when training the text encoder in distributed training. "
556 "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
557 )
558
553 instance_identifier = args.instance_identifier 559 instance_identifier = args.instance_identifier
554 560
555 if len(args.placeholder_token) != 0: 561 if len(args.placeholder_token) != 0:
@@ -587,6 +593,7 @@ def main():
587 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( 593 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
588 args.pretrained_model_name_or_path, subfolder='scheduler') 594 args.pretrained_model_name_or_path, subfolder='scheduler')
589 595
596 vae.enable_slicing()
590 unet.set_use_memory_efficient_attention_xformers(True) 597 unet.set_use_memory_efficient_attention_xformers(True)
591 598
592 if args.gradient_checkpointing: 599 if args.gradient_checkpointing:
@@ -903,7 +910,7 @@ def main():
903 sample_checkpoint = False 910 sample_checkpoint = False
904 911
905 for step, batch in enumerate(train_dataloader): 912 for step, batch in enumerate(train_dataloader):
906 with accelerator.accumulate(itertools.chain(unet, text_encoder)): 913 with accelerator.accumulate(unet):
907 # Convert images to latent space 914 # Convert images to latent space
908 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 915 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
909 latents = latents * 0.18215 916 latents = latents * 0.18215
diff --git a/infer.py b/infer.py
index ab5f247..eabeb5e 100644
--- a/infer.py
+++ b/infer.py
@@ -16,6 +16,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
16 16
17 17
18torch.backends.cuda.matmul.allow_tf32 = True 18torch.backends.cuda.matmul.allow_tf32 = True
19torch.backends.cudnn.benchmark = True
19 20
20 21
21default_args = { 22default_args = {
@@ -37,7 +38,7 @@ default_cmds = {
37 "height": 512, 38 "height": 512,
38 "batch_size": 1, 39 "batch_size": 1,
39 "batch_num": 1, 40 "batch_num": 1,
40 "steps": 50, 41 "steps": 30,
41 "guidance_scale": 7.0, 42 "guidance_scale": 7.0,
42 "seed": None, 43 "seed": None,
43 "config": None, 44 "config": None,
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index c77c4d1..9b51763 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,7 +20,6 @@ from diffusers import (
20 PNDMScheduler, 20 PNDMScheduler,
21) 21)
22from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput 22from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
23from diffusers.models.vae import DecoderOutput
24from diffusers.utils import logging 23from diffusers.utils import logging
25from transformers import CLIPTextModel, CLIPTokenizer 24from transformers import CLIPTextModel, CLIPTokenizer
26from models.clip.prompt import PromptProcessor 25from models.clip.prompt import PromptProcessor
@@ -70,7 +69,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
70 scheduler._internal_dict = FrozenDict(new_config) 69 scheduler._internal_dict = FrozenDict(new_config)
71 70
72 self.prompt_processor = PromptProcessor(tokenizer, text_encoder) 71 self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
73 self.use_slicing = False
74 72
75 self.register_modules( 73 self.register_modules(
76 vae=vae, 74 vae=vae,
@@ -108,9 +106,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
108 `attention_head_dim` must be a multiple of `slice_size`. 106 `attention_head_dim` must be a multiple of `slice_size`.
109 """ 107 """
110 if slice_size == "auto": 108 if slice_size == "auto":
111 # half the attention head size is usually a good trade-off between 109 if isinstance(self.unet.config.attention_head_dim, int):
112 # speed and memory 110 # half the attention head size is usually a good trade-off between
113 slice_size = self.unet.config.attention_head_dim // 2 111 # speed and memory
112 slice_size = self.unet.config.attention_head_dim // 2
113 else:
114 # if `attention_head_dim` is a list, take the smallest head size
115 slice_size = min(self.unet.config.attention_head_dim)
116
114 self.unet.set_attention_slice(slice_size) 117 self.unet.set_attention_slice(slice_size)
115 118
116 def disable_attention_slicing(self): 119 def disable_attention_slicing(self):
@@ -144,14 +147,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
144 When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several 147 When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
145 steps. This is useful to save some memory and allow larger batch sizes. 148 steps. This is useful to save some memory and allow larger batch sizes.
146 """ 149 """
147 self.use_slicing = True 150 self.vae.enable_slicing()
148 151
149 def disable_vae_slicing(self): 152 def disable_vae_slicing(self):
150 r""" 153 r"""
151 Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to 154 Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
152 computing decoding in one step. 155 computing decoding in one step.
153 """ 156 """
154 self.use_slicing = False 157 self.vae.disable_slicing()
155 158
156 @property 159 @property
157 def execution_device(self): 160 def execution_device(self):
@@ -297,20 +300,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
297 300
298 def decode_latents(self, latents): 301 def decode_latents(self, latents):
299 latents = 1 / 0.18215 * latents 302 latents = 1 / 0.18215 * latents
300 image = self.vae_decode(latents).sample 303 image = self.vae.decode(latents).sample
301 image = (image / 2 + 0.5).clamp(0, 1) 304 image = (image / 2 + 0.5).clamp(0, 1)
302 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 305 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
303 image = image.cpu().permute(0, 2, 3, 1).float().numpy() 306 image = image.cpu().permute(0, 2, 3, 1).float().numpy()
304 return image 307 return image
305 308
306 def vae_decode(self, latents):
307 if self.use_slicing:
308 decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)]
309 decoded = torch.cat(decoded_slices)
310 return DecoderOutput(sample=decoded)
311 else:
312 return self.vae.decode(latents)
313
314 @torch.no_grad() 309 @torch.no_grad()
315 def __call__( 310 def __call__(
316 self, 311 self,
diff --git a/textual_inversion.py b/textual_inversion.py
index 7ac9638..d6be522 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -31,6 +31,7 @@ logger = get_logger(__name__)
31 31
32 32
33torch.backends.cuda.matmul.allow_tf32 = True 33torch.backends.cuda.matmul.allow_tf32 = True
34torch.backends.cudnn.benchmark = True
34 35
35 36
36def parse_args(): 37def parse_args():