 environment.yaml    |   1
 train_dreambooth.py | 129
 train_ti.py         | 174
 training/lr.py      | 115
 4 files changed, 257 insertions(+), 162 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 179fa38..c006379 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -5,6 +5,7 @@ channels:
   - defaults
 dependencies:
   - cudatoolkit=11.3
+  - matplotlib=3.6.2
   - numpy=1.23.4
   - pip=22.3.1
   - python=3.9.15
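
Note: matplotlib is pulled in for the loss-vs-learning-rate plot drawn by the new training/lr.py module added at the end of this commit.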
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 08bc9e0..a62cec9 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -843,6 +843,58 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
+    def loop(batch):
+        # Convert images to latent space
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+        latents = latents * 0.18215
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                  (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+        noisy_latents = noisy_latents.to(dtype=unet.dtype)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        if args.num_class_images != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + args.prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        acc = (model_pred == latents).float().mean()
+
+        return loss, acc, bsz
+
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -859,54 +911,7 @@ def main():
 
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(unet):
-                    # Convert images to latent space
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-                    noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-                    # Get the text embedding for conditioning
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-                    # Predict the noise residual
-                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    # Get the target for loss depending on the prediction type
-                    if noise_scheduler.config.prediction_type == "epsilon":
-                        target = noise
-                    elif noise_scheduler.config.prediction_type == "v_prediction":
-                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                    else:
-                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-                    if args.num_class_images != 0:
-                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-                        target, target_prior = torch.chunk(target, 2, dim=0)
-
-                        # Compute instance loss
-                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
-
-                        # Compute prior loss
-                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-                        # Add the prior loss to the instance loss.
-                        loss = loss + args.prior_loss_weight * prior_loss
-                    else:
-                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                    acc = (model_pred == latents).float().mean()
+                    loss, acc, bsz = loop(batch)
 
                     accelerator.backward(loss)
 
@@ -960,33 +965,7 @@ def main():
 
                 with torch.inference_mode():
                     for step, batch in enumerate(val_dataloader):
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
-
-                        noise = torch.randn_like(latents)
-                        bsz = latents.shape[0]
-                        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                                  (bsz,), device=latents.device)
-                        timesteps = timesteps.long()
-
-                        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-                        noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-                        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-                        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                        # Get the target for loss depending on the prediction type
-                        if noise_scheduler.config.prediction_type == "epsilon":
-                            target = noise
-                        elif noise_scheduler.config.prediction_type == "v_prediction":
-                            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                        else:
-                            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                        acc = (model_pred == latents).float().mean()
+                        loss, acc, bsz = loop(batch)
 
                         avg_loss_val.update(loss.detach_(), bsz)
                         avg_acc_val.update(acc.detach_(), bsz)
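
Note on the extracted loop(): the accuracy line `acc = (model_pred == latents).float().mean()` checks exact floating-point equality against the wrong tensor, so at best it evaluates to ~0, and when --num_class_images is non-zero it compares tensors whose batch dimensions no longer match (model_pred was chunked in half, latents was not), which generally raises a shape error (or silently broadcasts when the chunked batch is 1). A closeness-based proxy against the loss target would be more meaningful; a sketch, where the 0.5 threshold is an arbitrary assumption, not from this commit:

    # Hypothetical replacement for the acc line inside loop(); `target`
    # has the same shape as model_pred even after torch.chunk.
    acc = (model_pred - target).abs().lt(0.5).float().mean()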
diff --git a/train_ti.py b/train_ti.py
index 6e30ac3..ab00b60 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,10 +1,8 @@
 import argparse
 import itertools
 import math
-import os
 import datetime
 import logging
-import json
 from pathlib import Path
 
 import torch
@@ -24,6 +22,7 @@ from common import load_text_embeddings, load_text_embedding, load_config
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
+from training.lr import LRFinder
 from training.ti import patch_trainable_embeddings
 from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
 from models.clip.prompt import PromptProcessor
@@ -173,6 +172,11 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
+        "--find_lr",
+        action="store_true",
+        help="Automatically find a learning rate (no training).",
+    )
+    parser.add_argument(
         "--learning_rate",
         type=float,
         default=1e-4,
@@ -225,7 +229,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -447,16 +451,23 @@ def main():
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
 
-    accelerator = Accelerator(
-        log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        mixed_precision=args.mixed_precision
-    )
+    if args.find_lr:
+        accelerator = Accelerator(
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            mixed_precision=args.mixed_precision
+        )
+    else:
+        basepath.mkdir(parents=True, exist_ok=True)
 
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+        accelerator = Accelerator(
+            log_with=LoggerType.TENSORBOARD,
+            logging_dir=f"{basepath}",
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            mixed_precision=args.mixed_precision
+        )
+
+        logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
     args.seed = args.seed or (torch.random.seed() >> 32)
     set_seed(args.seed)
@@ -537,6 +548,9 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -671,7 +685,9 @@ def main():
 
     warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
 
-    if args.lr_scheduler == "one_cycle":
+    if args.find_lr:
+        lr_scheduler = None
+    elif args.lr_scheduler == "one_cycle":
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
@@ -713,6 +729,63 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def loop(batch):
+        # Convert images to latent space
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+        latents = latents * 0.18215
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                  (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+        encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
+
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        if args.num_class_images != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + args.prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        acc = (model_pred == latents).float().mean()
+
+        return loss, acc, bsz
+
+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, loop)
+        lr_finder.run()
+        quit()
+
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
@@ -786,54 +859,7 @@ def main():
 
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
-                    # Convert images to latent space
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
-                    latents = latents * 0.18215
-
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    # Get the text embedding for conditioning
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-                    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
-
-                    # Predict the noise residual
-                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    # Get the target for loss depending on the prediction type
-                    if noise_scheduler.config.prediction_type == "epsilon":
-                        target = noise
-                    elif noise_scheduler.config.prediction_type == "v_prediction":
-                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                    else:
-                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-                    if args.num_class_images != 0:
-                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-                        target, target_prior = torch.chunk(target, 2, dim=0)
-
-                        # Compute instance loss
-                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
-
-                        # Compute prior loss
-                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-                        # Add the prior loss to the instance loss.
-                        loss = loss + args.prior_loss_weight * prior_loss
-                    else:
-                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                    acc = (model_pred == latents).float().mean()
+                    loss, acc, bsz = loop(batch)
 
                     accelerator.backward(loss)
 
@@ -873,33 +899,7 @@ def main():
 
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-                    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
-
-                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    # Get the target for loss depending on the prediction type
-                    if noise_scheduler.config.prediction_type == "epsilon":
-                        target = noise
-                    elif noise_scheduler.config.prediction_type == "v_prediction":
-                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                    else:
-                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                    acc = (model_pred == latents).float().mean()
+                    loss, acc, bsz = loop(batch)
 
                     avg_loss_val.update(loss.detach_(), bsz)
                     avg_acc_val.update(acc.detach_(), bsz)
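
With --find_lr set, the script swaps the normal run for an LR sweep: the base learning rate is forced up to 1e2 (the top of the sweep), no scheduler is created, and LRFinder drives loop() directly before quitting. Combined with the fifth-power ramp in training/lr.py below, that base rate implies a sweep from near zero up to 1e2; a small sketch of the values visited, assuming the run() defaults of 100 epochs:

    # LR values implied by base_lr = 1e2 and
    # lr_lambda = (epoch / num_epochs) ** 5 from training/lr.py.
    num_epochs = 100
    base_lr = 1e2
    lrs = [base_lr * (epoch / num_epochs) ** 5 for epoch in range(1, num_epochs + 1)]
    print(lrs[0], lrs[9], lrs[49], lrs[-1])  # approx. 1e-08, 0.001, 3.125, 100.0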
diff --git a/training/lr.py b/training/lr.py
new file mode 100644
index 0000000..dd37baa
--- /dev/null
+++ b/training/lr.py
@@ -0,0 +1,115 @@
+import numpy as np
+from torch.optim.lr_scheduler import LambdaLR
+from tqdm.auto import tqdm
+import matplotlib.pyplot as plt
+
+from training.util import AverageMeter
+
+
+class LRFinder():
+    def __init__(self, accelerator, model, optimizer, train_dataloader, loss_fn):
+        self.accelerator = accelerator
+        self.model = model
+        self.optimizer = optimizer
+        self.train_dataloader = train_dataloader
+        self.loss_fn = loss_fn
+
+    def run(self, num_epochs=100, num_steps=1, smooth_f=0.05, diverge_th=5):
+        best_loss = None
+        lrs = []
+        losses = []
+
+        lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs)
+
+        progress_bar = tqdm(
+            range(num_epochs * num_steps),
+            disable=not self.accelerator.is_local_main_process,
+            dynamic_ncols=True
+        )
+        progress_bar.set_description("Epoch X / Y")
+
+        for epoch in range(num_epochs):
+            progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
+
+            avg_loss = AverageMeter()
+
+            for step, batch in enumerate(self.train_dataloader):
+                with self.accelerator.accumulate(self.model):
+                    loss, acc, bsz = self.loss_fn(batch)
+
+                    self.accelerator.backward(loss)
+
+                    self.optimizer.step()
+                    self.optimizer.zero_grad(set_to_none=True)
+
+                    avg_loss.update(loss.detach_(), bsz)
+
+                if step >= num_steps:
+                    break
+
+                if self.accelerator.sync_gradients:
+                    progress_bar.update(1)
+
+            lr_scheduler.step()
+
+            loss = avg_loss.avg.item()
+            if epoch == 0:
+                best_loss = loss
+            else:
+                if smooth_f > 0:
+                    loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
+                if loss < best_loss:
+                    best_loss = loss
+
+            lr = lr_scheduler.get_last_lr()[0]
+
+            lrs.append(lr)
+            losses.append(loss)
+
+            progress_bar.set_postfix({
+                "loss": loss,
+                "best": best_loss,
+                "lr": lr,
+            })
+
+            if loss > diverge_th * best_loss:
+                print("Stopping early, the loss has diverged")
+                break
+
+        fig, ax = plt.subplots()
+        ax.plot(lrs, losses)
+
+        print("LR suggestion: steepest gradient")
+        min_grad_idx = None
+        try:
+            min_grad_idx = (np.gradient(np.array(losses))).argmin()
+        except ValueError:
+            print(
+                "Failed to compute the gradients, there might not be enough points."
+            )
+        if min_grad_idx is not None:
+            print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
+            ax.scatter(
+                lrs[min_grad_idx],
+                losses[min_grad_idx],
+                s=75,
+                marker="o",
+                color="red",
+                zorder=3,
+                label="steepest gradient",
+            )
+            ax.legend()
+
+        ax.set_xscale("log")
+        ax.set_xlabel("Learning rate")
+        ax.set_ylabel("Loss")
+
+        if fig is not None:
+            plt.show()
+
+
+def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1):
+    def lr_lambda(current_epoch: int):
+        return (current_epoch / num_epochs) ** 5
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
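
On its own, LRFinder only needs an Accelerator, a model, an optimizer whose initial lr marks the top of the sweep, a dataloader, and a loss function returning (loss, acc, batch_size). A minimal, self-contained sketch against a toy model; everything here apart from the LRFinder import is hypothetical and not from this repo:

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset
    from accelerate import Accelerator

    from training.lr import LRFinder

    accelerator = Accelerator()
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e2)  # top of the sweep
    dataloader = DataLoader(
        TensorDataset(torch.randn(64, 4), torch.randn(64, 1)),
        batch_size=8,
    )
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    def loss_fn(batch):
        x, y = batch
        pred = model(x)
        loss = F.mse_loss(pred, y)
        # Closeness-based stand-in for the trainers' acc metric.
        acc = (pred - y).abs().lt(0.5).float().mean()
        return loss, acc, x.shape[0]

    # Sweeps the LR over 100 epochs, plots loss vs. LR on a log axis,
    # and prints the steepest-gradient suggestion.
    LRFinder(accelerator, model, optimizer, dataloader, loss_fn).run()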