diff options
author | Volpeon <git@volpeon.ink> | 2022-10-10 17:55:08 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-10 17:55:08 +0200 |
commit | 35116fdf6fb1aedbe0da3cfa9372d53ddb455a26 (patch) | |
tree | 682e93c7b81c343fc64ecb5859e650083df15e4f | |
parent | Remove unused code (diff) | |
download | textual-inversion-diff-35116fdf6fb1aedbe0da3cfa9372d53ddb455a26.tar.gz textual-inversion-diff-35116fdf6fb1aedbe0da3cfa9372d53ddb455a26.tar.bz2 textual-inversion-diff-35116fdf6fb1aedbe0da3cfa9372d53ddb455a26.zip |
Added EMA support to Textual Inversion
-rw-r--r-- | dreambooth.py | 23 | ||||
-rw-r--r-- | textual_inversion.py | 54 |
2 files changed, 60 insertions, 17 deletions
diff --git a/dreambooth.py b/dreambooth.py index f7d31d2..02f83c6 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -129,7 +129,7 @@ def parse_args(): | |||
129 | parser.add_argument( | 129 | parser.add_argument( |
130 | "--learning_rate", | 130 | "--learning_rate", |
131 | type=float, | 131 | type=float, |
132 | default=1e-4, | 132 | default=1e-6, |
133 | help="Initial learning rate (after the potential warmup period) to use.", | 133 | help="Initial learning rate (after the potential warmup period) to use.", |
134 | ) | 134 | ) |
135 | parser.add_argument( | 135 | parser.add_argument( |
@@ -150,7 +150,7 @@ def parse_args(): | |||
150 | parser.add_argument( | 150 | parser.add_argument( |
151 | "--lr_warmup_steps", | 151 | "--lr_warmup_steps", |
152 | type=int, | 152 | type=int, |
153 | default=200, | 153 | default=600, |
154 | help="Number of steps for the warmup in the lr scheduler." | 154 | help="Number of steps for the warmup in the lr scheduler." |
155 | ) | 155 | ) |
156 | parser.add_argument( | 156 | parser.add_argument( |
@@ -162,12 +162,12 @@ def parse_args(): | |||
162 | parser.add_argument( | 162 | parser.add_argument( |
163 | "--ema_inv_gamma", | 163 | "--ema_inv_gamma", |
164 | type=float, | 164 | type=float, |
165 | default=0.1 | 165 | default=1.0 |
166 | ) | 166 | ) |
167 | parser.add_argument( | 167 | parser.add_argument( |
168 | "--ema_power", | 168 | "--ema_power", |
169 | type=float, | 169 | type=float, |
170 | default=1 | 170 | default=1.0 |
171 | ) | 171 | ) |
172 | parser.add_argument( | 172 | parser.add_argument( |
173 | "--ema_max_decay", | 173 | "--ema_max_decay", |
@@ -783,7 +783,12 @@ def main(): | |||
783 | if global_step % args.sample_frequency == 0: | 783 | if global_step % args.sample_frequency == 0: |
784 | sample_checkpoint = True | 784 | sample_checkpoint = True |
785 | 785 | ||
786 | logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} | 786 | logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} |
787 | if args.use_ema: | ||
788 | logs["ema_decay"] = ema_unet.decay | ||
789 | |||
790 | accelerator.log(logs, step=global_step) | ||
791 | |||
787 | local_progress_bar.set_postfix(**logs) | 792 | local_progress_bar.set_postfix(**logs) |
788 | 793 | ||
789 | if global_step >= args.max_train_steps: | 794 | if global_step >= args.max_train_steps: |
@@ -824,16 +829,12 @@ def main(): | |||
824 | local_progress_bar.update(1) | 829 | local_progress_bar.update(1) |
825 | global_progress_bar.update(1) | 830 | global_progress_bar.update(1) |
826 | 831 | ||
827 | logs = {"mode": "validation", "loss": loss} | 832 | logs = {"val/loss": loss} |
828 | local_progress_bar.set_postfix(**logs) | 833 | local_progress_bar.set_postfix(**logs) |
829 | 834 | ||
830 | val_loss /= len(val_dataloader) | 835 | val_loss /= len(val_dataloader) |
831 | 836 | ||
832 | accelerator.log({ | 837 | accelerator.log({"val/loss": val_loss}, step=global_step) |
833 | "train/loss": train_loss, | ||
834 | "val/loss": val_loss, | ||
835 | "lr": lr_scheduler.get_last_lr()[0] | ||
836 | }, step=global_step) | ||
837 | 838 | ||
838 | local_progress_bar.clear() | 839 | local_progress_bar.clear() |
839 | global_progress_bar.clear() | 840 | global_progress_bar.clear() |
diff --git a/textual_inversion.py b/textual_inversion.py index b01bdbc..e6d856a 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -17,6 +17,7 @@ from accelerate.logging import get_logger | |||
17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel | 18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel |
19 | from diffusers.optimization import get_scheduler | 19 | from diffusers.optimization import get_scheduler |
20 | from diffusers.training_utils import EMAModel | ||
20 | from PIL import Image | 21 | from PIL import Image |
21 | from tqdm.auto import tqdm | 22 | from tqdm.auto import tqdm |
22 | from transformers import CLIPTextModel, CLIPTokenizer | 23 | from transformers import CLIPTextModel, CLIPTokenizer |
@@ -149,10 +150,31 @@ def parse_args(): | |||
149 | parser.add_argument( | 150 | parser.add_argument( |
150 | "--lr_warmup_steps", | 151 | "--lr_warmup_steps", |
151 | type=int, | 152 | type=int, |
152 | default=200, | 153 | default=600, |
153 | help="Number of steps for the warmup in the lr scheduler." | 154 | help="Number of steps for the warmup in the lr scheduler." |
154 | ) | 155 | ) |
155 | parser.add_argument( | 156 | parser.add_argument( |
157 | "--use_ema", | ||
158 | action="store_true", | ||
159 | default=True, | ||
160 | help="Whether to use EMA model." | ||
161 | ) | ||
162 | parser.add_argument( | ||
163 | "--ema_inv_gamma", | ||
164 | type=float, | ||
165 | default=1.0 | ||
166 | ) | ||
167 | parser.add_argument( | ||
168 | "--ema_power", | ||
169 | type=float, | ||
170 | default=1.0 | ||
171 | ) | ||
172 | parser.add_argument( | ||
173 | "--ema_max_decay", | ||
174 | type=float, | ||
175 | default=0.9999 | ||
176 | ) | ||
177 | parser.add_argument( | ||
156 | "--use_8bit_adam", | 178 | "--use_8bit_adam", |
157 | action="store_true", | 179 | action="store_true", |
158 | help="Whether or not to use 8-bit Adam from bitsandbytes." | 180 | help="Whether or not to use 8-bit Adam from bitsandbytes." |
@@ -326,6 +348,7 @@ class Checkpointer: | |||
326 | unet, | 348 | unet, |
327 | tokenizer, | 349 | tokenizer, |
328 | text_encoder, | 350 | text_encoder, |
351 | ema_text_encoder, | ||
329 | placeholder_token, | 352 | placeholder_token, |
330 | placeholder_token_id, | 353 | placeholder_token_id, |
331 | output_dir: Path, | 354 | output_dir: Path, |
@@ -340,6 +363,7 @@ class Checkpointer: | |||
340 | self.unet = unet | 363 | self.unet = unet |
341 | self.tokenizer = tokenizer | 364 | self.tokenizer = tokenizer |
342 | self.text_encoder = text_encoder | 365 | self.text_encoder = text_encoder |
366 | self.ema_text_encoder = ema_text_encoder | ||
343 | self.placeholder_token = placeholder_token | 367 | self.placeholder_token = placeholder_token |
344 | self.placeholder_token_id = placeholder_token_id | 368 | self.placeholder_token_id = placeholder_token_id |
345 | self.output_dir = output_dir | 369 | self.output_dir = output_dir |
@@ -356,7 +380,8 @@ class Checkpointer: | |||
356 | checkpoints_path = self.output_dir.joinpath("checkpoints") | 380 | checkpoints_path = self.output_dir.joinpath("checkpoints") |
357 | checkpoints_path.mkdir(parents=True, exist_ok=True) | 381 | checkpoints_path.mkdir(parents=True, exist_ok=True) |
358 | 382 | ||
359 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) | 383 | unwrapped = self.accelerator.unwrap_model( |
384 | self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder) | ||
360 | 385 | ||
361 | # Save a checkpoint | 386 | # Save a checkpoint |
362 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] | 387 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] |
@@ -375,7 +400,8 @@ class Checkpointer: | |||
375 | def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): | 400 | def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): |
376 | samples_path = Path(self.output_dir).joinpath("samples") | 401 | samples_path = Path(self.output_dir).joinpath("samples") |
377 | 402 | ||
378 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) | 403 | unwrapped = self.accelerator.unwrap_model( |
404 | self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder) | ||
379 | scheduler = EulerAScheduler( | 405 | scheduler = EulerAScheduler( |
380 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | 406 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
381 | ) | 407 | ) |
@@ -681,6 +707,13 @@ def main(): | |||
681 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 707 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler |
682 | ) | 708 | ) |
683 | 709 | ||
710 | ema_text_encoder = EMAModel( | ||
711 | text_encoder, | ||
712 | inv_gamma=args.ema_inv_gamma, | ||
713 | power=args.ema_power, | ||
714 | max_value=args.ema_max_decay | ||
715 | ) if args.use_ema else None | ||
716 | |||
684 | # Move vae and unet to device | 717 | # Move vae and unet to device |
685 | vae.to(accelerator.device) | 718 | vae.to(accelerator.device) |
686 | unet.to(accelerator.device) | 719 | unet.to(accelerator.device) |
@@ -724,6 +757,7 @@ def main(): | |||
724 | unet=unet, | 757 | unet=unet, |
725 | tokenizer=tokenizer, | 758 | tokenizer=tokenizer, |
726 | text_encoder=text_encoder, | 759 | text_encoder=text_encoder, |
760 | ema_text_encoder=ema_text_encoder, | ||
727 | placeholder_token=args.placeholder_token, | 761 | placeholder_token=args.placeholder_token, |
728 | placeholder_token_id=placeholder_token_id, | 762 | placeholder_token_id=placeholder_token_id, |
729 | output_dir=basepath, | 763 | output_dir=basepath, |
@@ -825,6 +859,9 @@ def main(): | |||
825 | 859 | ||
826 | # Checks if the accelerator has performed an optimization step behind the scenes | 860 | # Checks if the accelerator has performed an optimization step behind the scenes |
827 | if accelerator.sync_gradients: | 861 | if accelerator.sync_gradients: |
862 | if args.use_ema: | ||
863 | ema_text_encoder.step(unet) | ||
864 | |||
828 | local_progress_bar.update(1) | 865 | local_progress_bar.update(1) |
829 | global_progress_bar.update(1) | 866 | global_progress_bar.update(1) |
830 | 867 | ||
@@ -843,7 +880,12 @@ def main(): | |||
843 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | 880 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" |
844 | }) | 881 | }) |
845 | 882 | ||
846 | logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} | 883 | logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} |
884 | if args.use_ema: | ||
885 | logs["ema_decay"] = ema_text_encoder.decay | ||
886 | |||
887 | accelerator.log(logs, step=global_step) | ||
888 | |||
847 | local_progress_bar.set_postfix(**logs) | 889 | local_progress_bar.set_postfix(**logs) |
848 | 890 | ||
849 | if global_step >= args.max_train_steps: | 891 | if global_step >= args.max_train_steps: |
@@ -884,12 +926,12 @@ def main(): | |||
884 | local_progress_bar.update(1) | 926 | local_progress_bar.update(1) |
885 | global_progress_bar.update(1) | 927 | global_progress_bar.update(1) |
886 | 928 | ||
887 | logs = {"mode": "validation", "loss": loss} | 929 | logs = {"val/loss": loss} |
888 | local_progress_bar.set_postfix(**logs) | 930 | local_progress_bar.set_postfix(**logs) |
889 | 931 | ||
890 | val_loss /= len(val_dataloader) | 932 | val_loss /= len(val_dataloader) |
891 | 933 | ||
892 | accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) | 934 | accelerator.log({"val/loss": val_loss}, step=global_step) |
893 | 935 | ||
894 | local_progress_bar.clear() | 936 | local_progress_bar.clear() |
895 | global_progress_bar.clear() | 937 | global_progress_bar.clear() |