 dreambooth.py        | 23
 textual_inversion.py | 54
 2 files changed, 60 insertions(+), 17 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index f7d31d2..02f83c6 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -129,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-4,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -150,7 +150,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=200,
+        default=600,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -162,12 +162,12 @@ def parse_args():
     parser.add_argument(
         "--ema_inv_gamma",
         type=float,
-        default=0.1
+        default=1.0
     )
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=1
+        default=1.0
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -783,7 +783,12 @@ def main():
                     if global_step % args.sample_frequency == 0:
                         sample_checkpoint = True
 
-                    logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                    logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                    if args.use_ema:
+                        logs["ema_decay"] = ema_unet.decay
+
+                    accelerator.log(logs, step=global_step)
+
                     local_progress_bar.set_postfix(**logs)
 
                 if global_step >= args.max_train_steps:
@@ -824,16 +829,12 @@
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
-                    logs = {"mode": "validation", "loss": loss}
+                    logs = {"val/loss": loss}
                     local_progress_bar.set_postfix(**logs)
 
             val_loss /= len(val_dataloader)
 
-            accelerator.log({
-                "train/loss": train_loss,
-                "val/loss": val_loss,
-                "lr": lr_scheduler.get_last_lr()[0]
-            }, step=global_step)
+            accelerator.log({"val/loss": val_loss}, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
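
A note on the EMA defaults changed above: the EMAModel these scripts use implements @crowsonkb's EMA warmup, where the decay at optimization step s is approximately 1 - (1 + s / inv_gamma) ** -power, clamped to ema_max_decay. The helper below is only an illustration of that schedule (it is not part of either script) and sketches what the newly logged ema_decay values look like under the new defaults:

    def ema_decay(step: int, inv_gamma: float = 1.0, power: float = 1.0,
                  max_decay: float = 0.9999) -> float:
        """Approximate EMA warmup schedule (assumed to match diffusers' EMAModel)."""
        if step <= 0:
            return 0.0
        value = 1 - (1 + step / inv_gamma) ** -power
        return min(value, max_decay)

    # With inv_gamma=1.0 and power=1.0 the decay ramps up as step / (step + 1).
    for s in (10, 100, 1000, 10000):
        print(s, round(ema_decay(s), 5))
    # -> 0.90909, 0.9901, 0.999, 0.9999 (clamped at max_decay)

With the previous dreambooth.py default of inv_gamma=0.1, the decay already reaches about 0.999 after 100 steps, so the averaged weights update only very slowly almost from the start of training; inv_gamma=1.0 with power=1.0 keeps the average closer to the live weights during the first few hundred steps.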
diff --git a/textual_inversion.py b/textual_inversion.py
index b01bdbc..e6d856a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -17,6 +17,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -149,10 +150,31 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=200,
+        default=600,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -326,6 +348,7 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
+        ema_text_encoder,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -340,6 +363,7 @@
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.ema_text_encoder = ema_text_encoder
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -356,7 +380,8 @@
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)
 
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
 
         # Save a checkpoint
         learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
@@ -375,7 +400,8 @@
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -681,6 +707,13 @@ def main():
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
+    ema_text_encoder = EMAModel(
+        text_encoder,
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay
+    ) if args.use_ema else None
+
     # Move vae and unet to device
     vae.to(accelerator.device)
     unet.to(accelerator.device)
@@ -724,6 +757,7 @@
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        ema_text_encoder=ema_text_encoder,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
@@ -825,6 +859,9 @@
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
+                    if args.use_ema:
+                        ema_text_encoder.step(text_encoder)
+
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
@@ -843,7 +880,12 @@
                         "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
                     })
 
-                logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                if args.use_ema:
+                    logs["ema_decay"] = ema_text_encoder.decay
+
+                accelerator.log(logs, step=global_step)
+
                 local_progress_bar.set_postfix(**logs)
 
                 if global_step >= args.max_train_steps:
@@ -884,12 +926,12 @@
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
-                    logs = {"mode": "validation", "loss": loss}
+                    logs = {"val/loss": loss}
                     local_progress_bar.set_postfix(**logs)
 
             val_loss /= len(val_dataloader)
 
-            accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
+            accelerator.log({"val/loss": val_loss}, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
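
Taken together, the textual_inversion.py changes thread an EMA copy of the text encoder through training, checkpointing, and sampling. A condensed, purely illustrative sketch of that flow, reusing the names from the diff (model_to_save is just a stand-in for what the Checkpointer unwraps):

    from diffusers.training_utils import EMAModel

    # Shadow copy of the text encoder; its decay follows the warmup schedule above.
    ema_text_encoder = EMAModel(
        text_encoder,
        inv_gamma=args.ema_inv_gamma,
        power=args.ema_power,
        max_value=args.ema_max_decay,
    ) if args.use_ema else None

    # In the training loop, after each real optimizer step:
    if accelerator.sync_gradients and ema_text_encoder is not None:
        ema_text_encoder.step(text_encoder)

    # Checkpointing and sample generation read from the averaged copy when EMA is on:
    model_to_save = (ema_text_encoder.averaged_model
                     if ema_text_encoder is not None else text_encoder)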
