diff options
author | Volpeon <git@volpeon.ink> | 2022-12-27 11:02:49 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-27 11:02:49 +0100 |
commit | 9d6c75262b6919758e781b8333428861a5bf7ede (patch) | |
tree | 72e5814413c18d476813867d87c8360c14aee200 /train_ti.py | |
parent | Set default dimensions to 768; add config inheritance (diff) | |
download | textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.gz textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.bz2 textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.zip |
Added learning rate finder
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 174 |
1 file changed, 87 insertions, 87 deletions
diff --git a/train_ti.py b/train_ti.py index 6e30ac3..ab00b60 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -1,10 +1,8 @@ | |||
1 | import argparse | 1 | import argparse |
2 | import itertools | 2 | import itertools |
3 | import math | 3 | import math |
4 | import os | ||
5 | import datetime | 4 | import datetime |
6 | import logging | 5 | import logging |
7 | import json | ||
8 | from pathlib import Path | 6 | from pathlib import Path |
9 | 7 | ||
10 | import torch | 8 | import torch |
@@ -24,6 +22,7 @@ from common import load_text_embeddings, load_text_embedding, load_config | |||
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 22 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import CSVDataModule, CSVDataItem | 23 | from data.csv import CSVDataModule, CSVDataItem |
26 | from training.optimization import get_one_cycle_schedule | 24 | from training.optimization import get_one_cycle_schedule |
25 | from training.lr import LRFinder | ||
27 | from training.ti import patch_trainable_embeddings | 26 | from training.ti import patch_trainable_embeddings |
28 | from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params | 27 | from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params |
29 | from models.clip.prompt import PromptProcessor | 28 | from models.clip.prompt import PromptProcessor |
@@ -173,6 +172,11 @@ def parse_args(): | |||
173 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", | 172 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", |
174 | ) | 173 | ) |
175 | parser.add_argument( | 174 | parser.add_argument( |
175 | "--find_lr", | ||
176 | action="store_true", | ||
177 | help="Automatically find a learning rate (no training).", | ||
178 | ) | ||
179 | parser.add_argument( | ||
176 | "--learning_rate", | 180 | "--learning_rate", |
177 | type=float, | 181 | type=float, |
178 | default=1e-4, | 182 | default=1e-4, |
@@ -225,7 +229,7 @@ def parse_args(): | |||
225 | parser.add_argument( | 229 | parser.add_argument( |
226 | "--adam_weight_decay", | 230 | "--adam_weight_decay", |
227 | type=float, | 231 | type=float, |
228 | default=0, | 232 | default=1e-2, |
229 | help="Weight decay to use." | 233 | help="Weight decay to use." |
230 | ) | 234 | ) |
231 | parser.add_argument( | 235 | parser.add_argument( |
@@ -447,16 +451,23 @@ def main(): | |||
447 | global_step_offset = args.global_step | 451 | global_step_offset = args.global_step |
448 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 452 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
449 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) | 453 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) |
450 | basepath.mkdir(parents=True, exist_ok=True) | ||
451 | 454 | ||
452 | accelerator = Accelerator( | 455 | if args.find_lr: |
453 | log_with=LoggerType.TENSORBOARD, | 456 | accelerator = Accelerator( |
454 | logging_dir=f"{basepath}", | 457 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
455 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 458 | mixed_precision=args.mixed_precision |
456 | mixed_precision=args.mixed_precision | 459 | ) |
457 | ) | 460 | else: |
461 | basepath.mkdir(parents=True, exist_ok=True) | ||
458 | 462 | ||
459 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | 463 | accelerator = Accelerator( |
464 | log_with=LoggerType.TENSORBOARD, | ||
465 | logging_dir=f"{basepath}", | ||
466 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
467 | mixed_precision=args.mixed_precision | ||
468 | ) | ||
469 | |||
470 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | ||
460 | 471 | ||
461 | args.seed = args.seed or (torch.random.seed() >> 32) | 472 | args.seed = args.seed or (torch.random.seed() >> 32) |
462 | set_seed(args.seed) | 473 | set_seed(args.seed) |
@@ -537,6 +548,9 @@ def main(): | |||
537 | args.train_batch_size * accelerator.num_processes | 548 | args.train_batch_size * accelerator.num_processes |
538 | ) | 549 | ) |
539 | 550 | ||
551 | if args.find_lr: | ||
552 | args.learning_rate = 1e2 | ||
553 | |||
540 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | 554 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs |
541 | if args.use_8bit_adam: | 555 | if args.use_8bit_adam: |
542 | try: | 556 | try: |
@@ -671,7 +685,9 @@ def main(): | |||
671 | 685 | ||
672 | warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps | 686 | warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps |
673 | 687 | ||
674 | if args.lr_scheduler == "one_cycle": | 688 | if args.find_lr: |
689 | lr_scheduler = None | ||
690 | elif args.lr_scheduler == "one_cycle": | ||
675 | lr_scheduler = get_one_cycle_schedule( | 691 | lr_scheduler = get_one_cycle_schedule( |
676 | optimizer=optimizer, | 692 | optimizer=optimizer, |
677 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 693 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
@@ -713,6 +729,63 @@ def main(): | |||
713 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 729 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
714 | val_steps = num_val_steps_per_epoch * num_epochs | 730 | val_steps = num_val_steps_per_epoch * num_epochs |
715 | 731 | ||
732 | def loop(batch): | ||
733 | # Convert images to latent space | ||
734 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | ||
735 | latents = latents * 0.18215 | ||
736 | |||
737 | # Sample noise that we'll add to the latents | ||
738 | noise = torch.randn_like(latents) | ||
739 | bsz = latents.shape[0] | ||
740 | # Sample a random timestep for each image | ||
741 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | ||
742 | (bsz,), device=latents.device) | ||
743 | timesteps = timesteps.long() | ||
744 | |||
745 | # Add noise to the latents according to the noise magnitude at each timestep | ||
746 | # (this is the forward diffusion process) | ||
747 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
748 | |||
749 | # Get the text embedding for conditioning | ||
750 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
751 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) | ||
752 | |||
753 | # Predict the noise residual | ||
754 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
755 | |||
756 | # Get the target for loss depending on the prediction type | ||
757 | if noise_scheduler.config.prediction_type == "epsilon": | ||
758 | target = noise | ||
759 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
760 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
761 | else: | ||
762 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
763 | |||
764 | if args.num_class_images != 0: | ||
765 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | ||
766 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | ||
767 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
768 | |||
769 | # Compute instance loss | ||
770 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | ||
771 | |||
772 | # Compute prior loss | ||
773 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | ||
774 | |||
775 | # Add the prior loss to the instance loss. | ||
776 | loss = loss + args.prior_loss_weight * prior_loss | ||
777 | else: | ||
778 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
779 | |||
780 | acc = (model_pred == latents).float().mean() | ||
781 | |||
782 | return loss, acc, bsz | ||
783 | |||
784 | if args.find_lr: | ||
785 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, loop) | ||
786 | lr_finder.run() | ||
787 | quit() | ||
788 | |||
716 | # We need to initialize the trackers we use, and also store our configuration. | 789 | # We need to initialize the trackers we use, and also store our configuration. |
717 | # The trackers initialize automatically on the main process. | 790 | # The trackers initialize automatically on the main process. |
718 | if accelerator.is_main_process: | 791 | if accelerator.is_main_process: |
@@ -786,54 +859,7 @@ def main(): | |||
786 | 859 | ||
787 | for step, batch in enumerate(train_dataloader): | 860 | for step, batch in enumerate(train_dataloader): |
788 | with accelerator.accumulate(text_encoder): | 861 | with accelerator.accumulate(text_encoder): |
789 | # Convert images to latent space | 862 | loss, acc, bsz = loop(batch) |
790 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | ||
791 | latents = latents * 0.18215 | ||
792 | |||
793 | # Sample noise that we'll add to the latents | ||
794 | noise = torch.randn_like(latents) | ||
795 | bsz = latents.shape[0] | ||
796 | # Sample a random timestep for each image | ||
797 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | ||
798 | (bsz,), device=latents.device) | ||
799 | timesteps = timesteps.long() | ||
800 | |||
801 | # Add noise to the latents according to the noise magnitude at each timestep | ||
802 | # (this is the forward diffusion process) | ||
803 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
804 | |||
805 | # Get the text embedding for conditioning | ||
806 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
807 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) | ||
808 | |||
809 | # Predict the noise residual | ||
810 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
811 | |||
812 | # Get the target for loss depending on the prediction type | ||
813 | if noise_scheduler.config.prediction_type == "epsilon": | ||
814 | target = noise | ||
815 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
816 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
817 | else: | ||
818 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
819 | |||
820 | if args.num_class_images != 0: | ||
821 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | ||
822 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | ||
823 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
824 | |||
825 | # Compute instance loss | ||
826 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | ||
827 | |||
828 | # Compute prior loss | ||
829 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | ||
830 | |||
831 | # Add the prior loss to the instance loss. | ||
832 | loss = loss + args.prior_loss_weight * prior_loss | ||
833 | else: | ||
834 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
835 | |||
836 | acc = (model_pred == latents).float().mean() | ||
837 | 863 | ||
838 | accelerator.backward(loss) | 864 | accelerator.backward(loss) |
839 | 865 | ||
@@ -873,33 +899,7 @@ def main(): | |||
873 | 899 | ||
874 | with torch.inference_mode(): | 900 | with torch.inference_mode(): |
875 | for step, batch in enumerate(val_dataloader): | 901 | for step, batch in enumerate(val_dataloader): |
876 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 902 | loss, acc, bsz = loop(batch) |
877 | latents = latents * 0.18215 | ||
878 | |||
879 | noise = torch.randn_like(latents) | ||
880 | bsz = latents.shape[0] | ||
881 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | ||
882 | (bsz,), device=latents.device) | ||
883 | timesteps = timesteps.long() | ||
884 | |||
885 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
886 | |||
887 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
888 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) | ||
889 | |||
890 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
891 | |||
892 | # Get the target for loss depending on the prediction type | ||
893 | if noise_scheduler.config.prediction_type == "epsilon": | ||
894 | target = noise | ||
895 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
896 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
897 | else: | ||
898 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
899 | |||
900 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
901 | |||
902 | acc = (model_pred == latents).float().mean() | ||
903 | 903 | ||
904 | avg_loss_val.update(loss.detach_(), bsz) | 904 | avg_loss_val.update(loss.detach_(), bsz) |
905 | avg_acc_val.update(acc.detach_(), bsz) | 905 | avg_acc_val.update(acc.detach_(), bsz) |