 dreambooth.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 44 insertions(+), 6 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 9f1b7af..f7d31d2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -16,6 +16,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -111,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=600,
+        default=5000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -128,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-7,
+        default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -140,7 +141,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="linear",
+        default="cosine",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -153,8 +154,30 @@ def parse_args():
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=0.1
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=1
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
+        default=True,
         help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
     parser.add_argument(
@@ -172,7 +195,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -298,6 +321,7 @@ class Checkpointer:
         accelerator,
         vae,
         unet,
+        ema_unet,
         tokenizer,
         text_encoder,
         output_dir: Path,
@@ -311,6 +335,7 @@ class Checkpointer:
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
+        self.ema_unet = ema_unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.output_dir = output_dir
@@ -324,7 +349,8 @@ class Checkpointer:
     def checkpoint(self):
         print("Saving model...")
 
-        unwrapped = self.accelerator.unwrap_model(self.unet)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
@@ -346,7 +372,8 @@ class Checkpointer:
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped = self.accelerator.unwrap_model(self.unet)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -602,6 +629,13 @@ def main():
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
+    ema_unet = EMAModel(
+        unet,
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay
+    ) if args.use_ema else None
+
     # Move text_encoder and vae to device
     text_encoder.to(accelerator.device)
     vae.to(accelerator.device)
@@ -643,6 +677,7 @@ def main():
         accelerator=accelerator,
         vae=vae,
         unet=unet,
+        ema_unet=ema_unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         output_dir=basepath,
@@ -737,6 +772,9 @@ def main():
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_unet.step(unet)
+
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
 
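
For reference, the EMA flow this patch wires in reduces to the sketch below. The names and call sites are taken from the hunks above (EMAModel is diffusers.training_utils.EMAModel as imported in this commit); the surrounding training-loop scaffolding is elided and only assumed here, not part of the patch.

    # Build a shadow copy of the UNet whose weights are an exponential moving
    # average of the trained weights; skip it entirely when --use_ema is off.
    ema_unet = EMAModel(
        unet,
        inv_gamma=args.ema_inv_gamma,   # controls how quickly the decay ramps up
        power=args.ema_power,
        max_value=args.ema_max_decay,   # decay is capped at 0.9999
    ) if args.use_ema else None

    # Inside the training loop: advance the shadow weights once per real
    # optimizer step, i.e. only when gradients have actually been synced.
    if accelerator.sync_gradients:
        if args.use_ema:
            ema_unet.step(unet)

    # When checkpointing or rendering samples: prefer the averaged weights
    # whenever EMA is enabled, otherwise fall back to the live UNet.
    unwrapped = accelerator.unwrap_model(
        ema_unet.averaged_model if ema_unet is not None else unet)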