Diffstat (limited to 'textual_inversion.py')
 textual_inversion.py | 74
1 file changed, 14 insertions, 60 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index e6d856a..3a3741d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -17,7 +17,6 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -112,7 +111,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=3000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -150,31 +149,10 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=600,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
-        "--use_ema",
-        action="store_true",
-        default=True,
-        help="Whether to use EMA model."
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
-    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -348,7 +326,6 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
-        ema_text_encoder,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -363,7 +340,6 @@
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
-        self.ema_text_encoder = ema_text_encoder
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -380,8 +356,7 @@
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)

-        unwrapped = self.accelerator.unwrap_model(
-            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)

         # Save a checkpoint
         learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
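Context for the line above: a textual inversion checkpoint only needs the single embedding row that was trained. A minimal sketch of how such a row is typically serialized (the helper name and file path are illustrative, not from this script):

    import torch

    def save_learned_embeds(text_encoder, placeholder_token, placeholder_token_id, path):
        # Illustrative helper: pull the one trained embedding row out of the
        # text encoder and save it keyed by the placeholder token string, so
        # it can later be loaded into any compatible text encoder.
        learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
        torch.save({placeholder_token: learned_embeds.detach().cpu()}, path)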
@@ -400,8 +375,7 @@
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")

-        unwrapped = self.accelerator.unwrap_model(
-            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -507,9 +481,7 @@ def main():
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path + '/tokenizer'
-        )
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
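For context on num_added_tokens: textual inversion scripts typically follow add_tokens with an embedding resize and an initializer-token copy. A hedged sketch of that standard pattern (initializer_token is an assumed variable, not shown in this diff):

    # Register the new token, grow the embedding matrix to match, then seed
    # the new row from an existing token the concept resembles.
    num_added_tokens = tokenizer.add_tokens(placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(f"Tokenizer already contains the token {placeholder_token}.")

    text_encoder.resize_token_embeddings(len(tokenizer))

    token_embeds = text_encoder.get_input_embeddings().weight.data
    initializer_token_id = tokenizer.convert_tokens_to_ids(initializer_token)
    placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
    token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]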
@@ -530,15 +502,10 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/text_encoder',
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path + '/vae',
-    )
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/unet',
-    )
+        args.pretrained_model_name_or_path, subfolder='unet')

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
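The switch from args.pretrained_model_name_or_path + '/...' to subfolder= is more than cosmetic: string concatenation only resolves for local directories, while subfolder= also works when the path is a Hugging Face Hub repo id. A minimal sketch, with an illustrative repo id:

    from diffusers import AutoencoderKL, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer

    model_id = "CompVis/stable-diffusion-v1-4"  # illustrative Hub repo id
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")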
@@ -707,13 +674,6 @@ def main():
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

-    ema_text_encoder = EMAModel(
-        text_encoder,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay
-    ) if args.use_ema else None
-
     # Move vae and unet to device
     vae.to(accelerator.device)
     unet.to(accelerator.device)
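For reference on what the removed block did: diffusers' EMAModel of this era warmed its decay up from 0 toward max_value over training. The rule below is reconstructed from the parameter names and may differ slightly across diffusers versions:

    def ema_decay(step, inv_gamma=1.0, power=1.0, max_value=0.9999):
        # Reconstructed warm-up rule (verify against your diffusers version):
        # decay grows from 0 toward max_value as training progresses.
        value = 1 - (1 + step / inv_gamma) ** -power
        return min(max_value, max(0.0, value))

    # With the removed defaults: step 0 -> 0.0, step 9 -> 0.9,
    # step 999 -> 0.999, capped at 0.9999 from step 9999 onward.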
@@ -757,7 +717,6 @@
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
-        ema_text_encoder=ema_text_encoder,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
@@ -777,7 +736,7 @@
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Batch X out of Y")
+    local_progress_bar.set_description("Epoch X / Y")

     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -788,7 +747,7 @@

     try:
         for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()

             text_encoder.train()
@@ -799,9 +758,8 @@
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
                     # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215

                     # Sample noise that we'll add to the latents
                     noise = torch.randn(latents.shape).to(latents.device)
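Two notes on this hunk: 0.18215 is the standard Stable Diffusion VAE latent scaling factor, and dropping torch.no_grad() only avoids extra autograd bookkeeping if the VAE's weights are frozen elsewhere in the script. The sampled noise then feeds the usual DDPM noising step; a hedged sketch of that continuation, assuming a DDPMScheduler instance as in the imports above:

    import torch
    from diffusers import DDPMScheduler

    def add_training_noise(latents, noise, noise_scheduler: DDPMScheduler):
        # Standard diffusers training pattern (sketch, not copied from this
        # script): sample one timestep per example, then noise the latents
        # that the UNet will be trained to denoise.
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],), device=latents.device,
        ).long()
        return noise_scheduler.add_noise(latents, noise, timesteps), timesteps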
@@ -859,9 +817,6 @@

                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    if args.use_ema:
-                        ema_text_encoder.step(unet)
-
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)

@@ -881,8 +836,6 @@
                     })

                 logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
-                if args.use_ema:
-                    logs["ema_decay"] = ema_text_encoder.decay

                 accelerator.log(logs, step=global_step)

@@ -937,7 +890,8 @@
                     global_progress_bar.clear()

                 if min_val_loss > val_loss:
-                    accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    accelerator.print(
+                        f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                     checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                     min_val_loss = val_loss

