Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py | 74
1 file changed, 14 insertions(+), 60 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index e6d856a..3a3741d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -17,7 +17,6 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -112,7 +111,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=3000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -150,31 +149,10 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=600,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
-        "--use_ema",
-        action="store_true",
-        default=True,
-        help="Whether to use EMA model."
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
-    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -348,7 +326,6 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
-        ema_text_encoder,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -363,7 +340,6 @@ class Checkpointer:
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
-        self.ema_text_encoder = ema_text_encoder
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -380,8 +356,7 @@ class Checkpointer:
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)

-        unwrapped = self.accelerator.unwrap_model(
-            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)

         # Save a checkpoint
         learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
@@ -400,8 +375,7 @@ class Checkpointer:
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")

-        unwrapped = self.accelerator.unwrap_model(
-            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -507,9 +481,7 @@ def main():
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path + '/tokenizer'
-        )
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
@@ -530,15 +502,10 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/text_encoder',
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path + '/vae',
-    )
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/unet',
-    )
+        args.pretrained_model_name_or_path, subfolder='unet')

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -707,13 +674,6 @@ def main():
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

-    ema_text_encoder = EMAModel(
-        text_encoder,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay
-    ) if args.use_ema else None
-
     # Move vae and unet to device
     vae.to(accelerator.device)
     unet.to(accelerator.device)
@@ -757,7 +717,6 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
-        ema_text_encoder=ema_text_encoder,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
@@ -777,7 +736,7 @@ def main():
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Batch X out of Y")
+    local_progress_bar.set_description("Epoch X / Y")

     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -788,7 +747,7 @@ def main():

     try:
         for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()

             text_encoder.train()
@@ -799,9 +758,8 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
                     # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215

                     # Sample noise that we'll add to the latents
                     noise = torch.randn(latents.shape).to(latents.device)
@@ -859,9 +817,6 @@ def main():

                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    if args.use_ema:
-                        ema_text_encoder.step(unet)
-
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)

@@ -881,8 +836,6 @@ def main():
                         })

                 logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
-                if args.use_ema:
-                    logs["ema_decay"] = ema_text_encoder.decay

                 accelerator.log(logs, step=global_step)

@@ -937,7 +890,8 @@ def main():
                global_progress_bar.clear()

                if min_val_loss > val_loss:
-                    accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    accelerator.print(
+                        f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                    min_val_loss = val_loss

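For reference, a minimal sketch of the component-loading pattern the patch switches to: instead of concatenating subdirectory paths onto args.pretrained_model_name_or_path, each component is loaded with the subfolder= argument of from_pretrained, which resolves the same layout whether the value is a local checkpoint directory or a Hub repository id. The repository id below is illustrative, not part of this commit.

from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

# Local path or Hub id of a Stable Diffusion checkpoint (illustrative value).
pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"

# Each component lives in its own subfolder of the checkpoint
# (tokenizer/, text_encoder/, vae/, unet/); subfolder= resolves it.
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")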