Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py  74
1 file changed, 14 insertions, 60 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index e6d856a..3a3741d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -17,7 +17,6 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -112,7 +111,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=3000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -150,31 +149,10 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=600,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
-        "--use_ema",
-        action="store_true",
-        default=True,
-        help="Whether to use EMA model."
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
-    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -348,7 +326,6 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
-        ema_text_encoder,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -363,7 +340,6 @@ class Checkpointer:
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
-        self.ema_text_encoder = ema_text_encoder
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -380,8 +356,7 @@ class Checkpointer:
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)

-        unwrapped = self.accelerator.unwrap_model(
-            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)

         # Save a checkpoint
         learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
@@ -400,8 +375,7 @@ class Checkpointer:
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")

-        unwrapped = self.accelerator.unwrap_model(
-            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -507,9 +481,7 @@ def main():
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path + '/tokenizer'
-        )
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
@@ -530,15 +502,10 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/text_encoder',
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path + '/vae',
-    )
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/unet',
-    )
+        args.pretrained_model_name_or_path, subfolder='unet')

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -707,13 +674,6 @@ def main():
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

-    ema_text_encoder = EMAModel(
-        text_encoder,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay
-    ) if args.use_ema else None
-
     # Move vae and unet to device
     vae.to(accelerator.device)
     unet.to(accelerator.device)
@@ -757,7 +717,6 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
-        ema_text_encoder=ema_text_encoder,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
@@ -777,7 +736,7 @@ def main():
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Batch X out of Y")
+    local_progress_bar.set_description("Epoch X / Y")

     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -788,7 +747,7 @@ def main():

     try:
         for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()

             text_encoder.train()
@@ -799,9 +758,8 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
                     # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215

                     # Sample noise that we'll add to the latents
                     noise = torch.randn(latents.shape).to(latents.device)
@@ -859,9 +817,6 @@ def main():

                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    if args.use_ema:
-                        ema_text_encoder.step(unet)
-
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)

@@ -881,8 +836,6 @@ def main():
                         })

                     logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
-                    if args.use_ema:
-                        logs["ema_decay"] = ema_text_encoder.decay

                     accelerator.log(logs, step=global_step)

@@ -937,7 +890,8 @@ def main():
                 global_progress_bar.clear()

                 if min_val_loss > val_loss:
-                    accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    accelerator.print(
+                        f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                     checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                     min_val_loss = val_loss

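
For reference, a minimal sketch (not part of this commit) of the subfolder-based loading pattern the updated script switches to; the checkpoint path below is only a placeholder for any Stable Diffusion model saved in the standard diffusers directory layout:

    # Hypothetical example path; substitute a local checkpoint directory or hub id.
    from diffusers import AutoencoderKL, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer

    model_path = "path/to/stable-diffusion-checkpoint"

    # Each sub-model lives in its own subfolder of the checkpoint, so it is loaded
    # with `subfolder=` instead of concatenating the path by hand.
    tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder='text_encoder')
    vae = AutoencoderKL.from_pretrained(model_path, subfolder='vae')
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder='unet')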