author    Volpeon <git@volpeon.ink>  2022-10-10 17:55:08 +0200
committer Volpeon <git@volpeon.ink>  2022-10-10 17:55:08 +0200
commit    35116fdf6fb1aedbe0da3cfa9372d53ddb455a26 (patch)
tree      682e93c7b81c343fc64ecb5859e650083df15e4f
parent    Remove unused code (diff)
Added EMA support to Textual Inversion
-rw-r--r--  dreambooth.py        | 23
-rw-r--r--  textual_inversion.py | 54
2 files changed, 60 insertions(+), 17 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index f7d31d2..02f83c6 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -129,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-4,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -150,7 +150,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=200,
+        default=600,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -162,12 +162,12 @@ def parse_args():
     parser.add_argument(
         "--ema_inv_gamma",
         type=float,
-        default=0.1
+        default=1.0
     )
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=1
+        default=1.0
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -783,7 +783,12 @@ def main():
                 if global_step % args.sample_frequency == 0:
                     sample_checkpoint = True

-            logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+            logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+            if args.use_ema:
+                logs["ema_decay"] = ema_unet.decay
+
+            accelerator.log(logs, step=global_step)
+
             local_progress_bar.set_postfix(**logs)

             if global_step >= args.max_train_steps:
@@ -824,16 +829,12 @@ def main():
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)

-            logs = {"mode": "validation", "loss": loss}
+            logs = {"val/loss": loss}
             local_progress_bar.set_postfix(**logs)

         val_loss /= len(val_dataloader)

-        accelerator.log({
-            "train/loss": train_loss,
-            "val/loss": val_loss,
-            "lr": lr_scheduler.get_last_lr()[0]
-        }, step=global_step)
+        accelerator.log({"val/loss": val_loss}, step=global_step)

         local_progress_bar.clear()
         global_progress_bar.clear()
diff --git a/textual_inversion.py b/textual_inversion.py
index b01bdbc..e6d856a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -17,6 +17,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -149,10 +150,31 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=200,
+        default=600,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -326,6 +348,7 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
+        ema_text_encoder,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -340,6 +363,7 @@ class Checkpointer:
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.ema_text_encoder = ema_text_encoder
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -356,7 +380,8 @@ class Checkpointer:
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)

-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)

         # Save a checkpoint
         learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
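The checkpoint only needs the one embedding row the training actually learns. A minimal stand-alone sketch of what the hunk above amounts to; text_encoder, placeholder_token, placeholder_token_id and the saved dict layout are stand-ins for the objects the Checkpointer already holds, not the script's exact on-disk format:

import torch

def save_placeholder_embedding(text_encoder, placeholder_token, placeholder_token_id, path):
    # Pull the single learned row out of the token embedding matrix
    # (the EMA-averaged encoder when EMA is enabled) and save it on its own.
    learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
    torch.save({placeholder_token: learned_embeds.detach().cpu()}, path)
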
@@ -375,7 +400,8 @@ class Checkpointer:
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")

-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -681,6 +707,13 @@ def main():
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

+    ema_text_encoder = EMAModel(
+        text_encoder,
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay
+    ) if args.use_ema else None
+
     # Move vae and unet to device
     vae.to(accelerator.device)
     unet.to(accelerator.device)
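For context, the EMAModel interface this script leans on (the live model passed to the constructor, step() called after each optimizer update, averaged_model read back for checkpointing and sampling) matches the helper diffusers shipped at the time; newer diffusers releases changed this API. A rough, hedged usage sketch with a toy module standing in for the text encoder:

import torch
from diffusers.training_utils import EMAModel  # old-style API assumed here

# Stand-in for the CLIP text encoder.
model = torch.nn.Sequential(torch.nn.Linear(8, 8))
ema = EMAModel(model, inv_gamma=1.0, power=1.0, max_value=0.9999)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(3):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model)              # shadow weights track the live weights

eval_model = ema.averaged_model  # what the Checkpointer unwraps above
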
@@ -724,6 +757,7 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        ema_text_encoder=ema_text_encoder,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
@@ -825,6 +859,9 @@ def main():

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_text_encoder.step(text_encoder)
+
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)

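Stepping the EMA on every synchronized optimizer step boils down to a per-parameter exponential moving average. A hand-rolled sketch of that update, equivalent in spirit to EMAModel.step() but independent of the diffusers helper:

import copy
import torch

@torch.no_grad()
def ema_step(shadow_model, live_model, decay):
    # shadow <- decay * shadow + (1 - decay) * live, parameter by parameter.
    for shadow_p, live_p in zip(shadow_model.parameters(), live_model.parameters()):
        shadow_p.mul_(decay).add_(live_p, alpha=1 - decay)

live = torch.nn.Linear(4, 4)
shadow = copy.deepcopy(live).eval()
ema_step(shadow, live, decay=0.999)
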
@@ -843,7 +880,12 @@ def main():
843 "resume_checkpoint": f"{basepath}/checkpoints/last.bin" 880 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
844 }) 881 })
845 882
846 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} 883 logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
884 if args.use_ema:
885 logs["ema_decay"] = ema_text_encoder.decay
886
887 accelerator.log(logs, step=global_step)
888
847 local_progress_bar.set_postfix(**logs) 889 local_progress_bar.set_postfix(**logs)
848 890
849 if global_step >= args.max_train_steps: 891 if global_step >= args.max_train_steps:
@@ -884,12 +926,12 @@ def main():
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)

-            logs = {"mode": "validation", "loss": loss}
+            logs = {"val/loss": loss}
             local_progress_bar.set_postfix(**logs)

         val_loss /= len(val_dataloader)

-        accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
+        accelerator.log({"val/loss": val_loss}, step=global_step)

         local_progress_bar.clear()
         global_progress_bar.clear()