 dreambooth.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 44 insertions(+), 6 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 9f1b7af..f7d31d2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -16,6 +16,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -111,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=600,
+        default=5000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -128,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-7,
+        default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -140,7 +141,7 @@
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="linear",
+        default="cosine",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -153,8 +154,30 @@ def parse_args():
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=0.1
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=1
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
+        default=True,
         help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
     parser.add_argument(
@@ -172,7 +195,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -298,6 +321,7 @@ class Checkpointer:
         accelerator,
         vae,
         unet,
+        ema_unet,
         tokenizer,
         text_encoder,
         output_dir: Path,
@@ -311,6 +335,7 @@ class Checkpointer:
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
+        self.ema_unet = ema_unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.output_dir = output_dir
@@ -324,7 +349,8 @@ class Checkpointer:
     def checkpoint(self):
         print("Saving model...")
 
-        unwrapped = self.accelerator.unwrap_model(self.unet)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
@@ -346,7 +372,8 @@
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped = self.accelerator.unwrap_model(self.unet)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -602,6 +629,13 @@ def main():
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
+    ema_unet = EMAModel(
+        unet,
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay
+    ) if args.use_ema else None
+
     # Move text_encoder and vae to device
     text_encoder.to(accelerator.device)
     vae.to(accelerator.device)
@@ -643,6 +677,7 @@ def main():
         accelerator=accelerator,
         vae=vae,
         unet=unet,
+        ema_unet=ema_unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         output_dir=basepath,
@@ -737,6 +772,9 @@
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_unet.step(unet)
+
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
 
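
For reference, a minimal sketch of the EMA mechanics this commit wires up, assuming the decay schedule diffusers' EMAModel used around this time: decay = 1 - (1 + step / inv_gamma) ** -power, capped at max_value. SimpleEMA and ema_decay below are illustrative names, not the library API; the real EMAModel is what the script imports.

import copy

import torch


def ema_decay(step: int, inv_gamma: float = 0.1, power: float = 1.0,
              max_value: float = 0.9999) -> float:
    # Decay ramps from 0 toward max_value, so early (noisy) weights are
    # forgotten quickly while later weights are averaged over a long window.
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(0.0, min(value, max_value))


class SimpleEMA:
    def __init__(self, model: torch.nn.Module, **schedule_kwargs):
        # Frozen shadow copy that accumulates the average; mirrors the
        # `averaged_model` attribute the Checkpointer reads above.
        self.averaged_model = copy.deepcopy(model).eval().requires_grad_(False)
        self.schedule_kwargs = schedule_kwargs
        self.optimization_step = 0

    @torch.no_grad()
    def step(self, model: torch.nn.Module):
        # Called once per real optimizer step (under accelerator.sync_gradients
        # in the training loop): shadow <- decay * shadow + (1 - decay) * live.
        self.optimization_step += 1
        decay = ema_decay(self.optimization_step, **self.schedule_kwargs)
        for s_param, param in zip(self.averaged_model.parameters(),
                                  model.parameters()):
            s_param.mul_(decay).add_(param, alpha=1 - decay)

With the defaults added here (inv_gamma=0.1, power=1), the decay is already ~0.909 after one step and ~0.999 after 100, saturating at ema_max_decay=0.9999; a larger inv_gamma or smaller power would keep the decay lower for longer early in training.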
