summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-27 11:02:49 +0100
committerVolpeon <git@volpeon.ink>2022-12-27 11:02:49 +0100
commit9d6c75262b6919758e781b8333428861a5bf7ede (patch)
tree72e5814413c18d476813867d87c8360c14aee200
parentSet default dimensions to 768; add config inheritance (diff)
downloadtextual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.gz
textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.bz2
textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.zip
Added learning rate finder
-rw-r--r--environment.yaml1
-rw-r--r--train_dreambooth.py129
-rw-r--r--train_ti.py174
-rw-r--r--training/lr.py115
4 files changed, 257 insertions, 162 deletions
diff --git a/environment.yaml b/environment.yaml
index 179fa38..c006379 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -5,6 +5,7 @@ channels:
5 - defaults 5 - defaults
6dependencies: 6dependencies:
7 - cudatoolkit=11.3 7 - cudatoolkit=11.3
8 - matplotlib=3.6.2
8 - numpy=1.23.4 9 - numpy=1.23.4
9 - pip=22.3.1 10 - pip=22.3.1
10 - python=3.9.15 11 - python=3.9.15
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 08bc9e0..a62cec9 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -843,6 +843,58 @@ def main():
843 ) 843 )
844 global_progress_bar.set_description("Total progress") 844 global_progress_bar.set_description("Total progress")
845 845
846 def loop(batch):
847 # Convert images to latent space
848 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
849 latents = latents * 0.18215
850
851 # Sample noise that we'll add to the latents
852 noise = torch.randn_like(latents)
853 bsz = latents.shape[0]
854 # Sample a random timestep for each image
855 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
856 (bsz,), device=latents.device)
857 timesteps = timesteps.long()
858
859 # Add noise to the latents according to the noise magnitude at each timestep
860 # (this is the forward diffusion process)
861 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
862 noisy_latents = noisy_latents.to(dtype=unet.dtype)
863
864 # Get the text embedding for conditioning
865 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
866
867 # Predict the noise residual
868 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
869
870 # Get the target for loss depending on the prediction type
871 if noise_scheduler.config.prediction_type == "epsilon":
872 target = noise
873 elif noise_scheduler.config.prediction_type == "v_prediction":
874 target = noise_scheduler.get_velocity(latents, noise, timesteps)
875 else:
876 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
877
878 if args.num_class_images != 0:
879 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
880 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
881 target, target_prior = torch.chunk(target, 2, dim=0)
882
883 # Compute instance loss
884 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
885
886 # Compute prior loss
887 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
888
889 # Add the prior loss to the instance loss.
890 loss = loss + args.prior_loss_weight * prior_loss
891 else:
892 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
893
894 acc = (model_pred == latents).float().mean()
895
896 return loss, acc, bsz
897
846 try: 898 try:
847 for epoch in range(num_epochs): 899 for epoch in range(num_epochs):
848 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 900 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -859,54 +911,7 @@ def main():
859 911
860 for step, batch in enumerate(train_dataloader): 912 for step, batch in enumerate(train_dataloader):
861 with accelerator.accumulate(unet): 913 with accelerator.accumulate(unet):
862 # Convert images to latent space 914 loss, acc, bsz = loop(batch)
863 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
864 latents = latents * 0.18215
865
866 # Sample noise that we'll add to the latents
867 noise = torch.randn_like(latents)
868 bsz = latents.shape[0]
869 # Sample a random timestep for each image
870 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
871 (bsz,), device=latents.device)
872 timesteps = timesteps.long()
873
874 # Add noise to the latents according to the noise magnitude at each timestep
875 # (this is the forward diffusion process)
876 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
877 noisy_latents = noisy_latents.to(dtype=unet.dtype)
878
879 # Get the text embedding for conditioning
880 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
881
882 # Predict the noise residual
883 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
884
885 # Get the target for loss depending on the prediction type
886 if noise_scheduler.config.prediction_type == "epsilon":
887 target = noise
888 elif noise_scheduler.config.prediction_type == "v_prediction":
889 target = noise_scheduler.get_velocity(latents, noise, timesteps)
890 else:
891 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
892
893 if args.num_class_images != 0:
894 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
895 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
896 target, target_prior = torch.chunk(target, 2, dim=0)
897
898 # Compute instance loss
899 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
900
901 # Compute prior loss
902 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
903
904 # Add the prior loss to the instance loss.
905 loss = loss + args.prior_loss_weight * prior_loss
906 else:
907 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
908
909 acc = (model_pred == latents).float().mean()
910 915
911 accelerator.backward(loss) 916 accelerator.backward(loss)
912 917
@@ -960,33 +965,7 @@ def main():
960 965
961 with torch.inference_mode(): 966 with torch.inference_mode():
962 for step, batch in enumerate(val_dataloader): 967 for step, batch in enumerate(val_dataloader):
963 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 968 loss, acc, bsz = loop(batch)
964 latents = latents * 0.18215
965
966 noise = torch.randn_like(latents)
967 bsz = latents.shape[0]
968 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
969 (bsz,), device=latents.device)
970 timesteps = timesteps.long()
971
972 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
973 noisy_latents = noisy_latents.to(dtype=unet.dtype)
974
975 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
976
977 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
978
979 # Get the target for loss depending on the prediction type
980 if noise_scheduler.config.prediction_type == "epsilon":
981 target = noise
982 elif noise_scheduler.config.prediction_type == "v_prediction":
983 target = noise_scheduler.get_velocity(latents, noise, timesteps)
984 else:
985 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
986
987 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
988
989 acc = (model_pred == latents).float().mean()
990 969
991 avg_loss_val.update(loss.detach_(), bsz) 970 avg_loss_val.update(loss.detach_(), bsz)
992 avg_acc_val.update(acc.detach_(), bsz) 971 avg_acc_val.update(acc.detach_(), bsz)
diff --git a/train_ti.py b/train_ti.py
index 6e30ac3..ab00b60 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,10 +1,8 @@
1import argparse 1import argparse
2import itertools 2import itertools
3import math 3import math
4import os
5import datetime 4import datetime
6import logging 5import logging
7import json
8from pathlib import Path 6from pathlib import Path
9 7
10import torch 8import torch
@@ -24,6 +22,7 @@ from common import load_text_embeddings, load_text_embedding, load_config
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 22from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule, CSVDataItem 23from data.csv import CSVDataModule, CSVDataItem
26from training.optimization import get_one_cycle_schedule 24from training.optimization import get_one_cycle_schedule
25from training.lr import LRFinder
27from training.ti import patch_trainable_embeddings 26from training.ti import patch_trainable_embeddings
28from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params 27from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
29from models.clip.prompt import PromptProcessor 28from models.clip.prompt import PromptProcessor
@@ -173,6 +172,11 @@ def parse_args():
173 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 172 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
174 ) 173 )
175 parser.add_argument( 174 parser.add_argument(
175 "--find_lr",
176 action="store_true",
177 help="Automatically find a learning rate (no training).",
178 )
179 parser.add_argument(
176 "--learning_rate", 180 "--learning_rate",
177 type=float, 181 type=float,
178 default=1e-4, 182 default=1e-4,
@@ -225,7 +229,7 @@ def parse_args():
225 parser.add_argument( 229 parser.add_argument(
226 "--adam_weight_decay", 230 "--adam_weight_decay",
227 type=float, 231 type=float,
228 default=0, 232 default=1e-2,
229 help="Weight decay to use." 233 help="Weight decay to use."
230 ) 234 )
231 parser.add_argument( 235 parser.add_argument(
@@ -447,16 +451,23 @@ def main():
447 global_step_offset = args.global_step 451 global_step_offset = args.global_step
448 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 452 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
449 basepath = Path(args.output_dir).joinpath(slugify(args.project), now) 453 basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
450 basepath.mkdir(parents=True, exist_ok=True)
451 454
452 accelerator = Accelerator( 455 if args.find_lr:
453 log_with=LoggerType.TENSORBOARD, 456 accelerator = Accelerator(
454 logging_dir=f"{basepath}", 457 gradient_accumulation_steps=args.gradient_accumulation_steps,
455 gradient_accumulation_steps=args.gradient_accumulation_steps, 458 mixed_precision=args.mixed_precision
456 mixed_precision=args.mixed_precision 459 )
457 ) 460 else:
461 basepath.mkdir(parents=True, exist_ok=True)
458 462
459 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) 463 accelerator = Accelerator(
464 log_with=LoggerType.TENSORBOARD,
465 logging_dir=f"{basepath}",
466 gradient_accumulation_steps=args.gradient_accumulation_steps,
467 mixed_precision=args.mixed_precision
468 )
469
470 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
460 471
461 args.seed = args.seed or (torch.random.seed() >> 32) 472 args.seed = args.seed or (torch.random.seed() >> 32)
462 set_seed(args.seed) 473 set_seed(args.seed)
@@ -537,6 +548,9 @@ def main():
537 args.train_batch_size * accelerator.num_processes 548 args.train_batch_size * accelerator.num_processes
538 ) 549 )
539 550
551 if args.find_lr:
552 args.learning_rate = 1e2
553
540 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 554 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
541 if args.use_8bit_adam: 555 if args.use_8bit_adam:
542 try: 556 try:
@@ -671,7 +685,9 @@ def main():
671 685
672 warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps 686 warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
673 687
674 if args.lr_scheduler == "one_cycle": 688 if args.find_lr:
689 lr_scheduler = None
690 elif args.lr_scheduler == "one_cycle":
675 lr_scheduler = get_one_cycle_schedule( 691 lr_scheduler = get_one_cycle_schedule(
676 optimizer=optimizer, 692 optimizer=optimizer,
677 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 693 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
@@ -713,6 +729,63 @@ def main():
713 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 729 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
714 val_steps = num_val_steps_per_epoch * num_epochs 730 val_steps = num_val_steps_per_epoch * num_epochs
715 731
732 def loop(batch):
733 # Convert images to latent space
734 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
735 latents = latents * 0.18215
736
737 # Sample noise that we'll add to the latents
738 noise = torch.randn_like(latents)
739 bsz = latents.shape[0]
740 # Sample a random timestep for each image
741 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
742 (bsz,), device=latents.device)
743 timesteps = timesteps.long()
744
745 # Add noise to the latents according to the noise magnitude at each timestep
746 # (this is the forward diffusion process)
747 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
748
749 # Get the text embedding for conditioning
750 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
751 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
752
753 # Predict the noise residual
754 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
755
756 # Get the target for loss depending on the prediction type
757 if noise_scheduler.config.prediction_type == "epsilon":
758 target = noise
759 elif noise_scheduler.config.prediction_type == "v_prediction":
760 target = noise_scheduler.get_velocity(latents, noise, timesteps)
761 else:
762 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
763
764 if args.num_class_images != 0:
765 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
766 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
767 target, target_prior = torch.chunk(target, 2, dim=0)
768
769 # Compute instance loss
770 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
771
772 # Compute prior loss
773 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
774
775 # Add the prior loss to the instance loss.
776 loss = loss + args.prior_loss_weight * prior_loss
777 else:
778 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
779
780 acc = (model_pred == latents).float().mean()
781
782 return loss, acc, bsz
783
784 if args.find_lr:
785 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, loop)
786 lr_finder.run()
787 quit()
788
716 # We need to initialize the trackers we use, and also store our configuration. 789 # We need to initialize the trackers we use, and also store our configuration.
717 # The trackers initializes automatically on the main process. 790 # The trackers initializes automatically on the main process.
718 if accelerator.is_main_process: 791 if accelerator.is_main_process:
@@ -786,54 +859,7 @@ def main():
786 859
787 for step, batch in enumerate(train_dataloader): 860 for step, batch in enumerate(train_dataloader):
788 with accelerator.accumulate(text_encoder): 861 with accelerator.accumulate(text_encoder):
789 # Convert images to latent space 862 loss, acc, bsz = loop(batch)
790 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
791 latents = latents * 0.18215
792
793 # Sample noise that we'll add to the latents
794 noise = torch.randn_like(latents)
795 bsz = latents.shape[0]
796 # Sample a random timestep for each image
797 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
798 (bsz,), device=latents.device)
799 timesteps = timesteps.long()
800
801 # Add noise to the latents according to the noise magnitude at each timestep
802 # (this is the forward diffusion process)
803 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
804
805 # Get the text embedding for conditioning
806 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
807 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
808
809 # Predict the noise residual
810 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
811
812 # Get the target for loss depending on the prediction type
813 if noise_scheduler.config.prediction_type == "epsilon":
814 target = noise
815 elif noise_scheduler.config.prediction_type == "v_prediction":
816 target = noise_scheduler.get_velocity(latents, noise, timesteps)
817 else:
818 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
819
820 if args.num_class_images != 0:
821 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
822 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
823 target, target_prior = torch.chunk(target, 2, dim=0)
824
825 # Compute instance loss
826 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
827
828 # Compute prior loss
829 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
830
831 # Add the prior loss to the instance loss.
832 loss = loss + args.prior_loss_weight * prior_loss
833 else:
834 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
835
836 acc = (model_pred == latents).float().mean()
837 863
838 accelerator.backward(loss) 864 accelerator.backward(loss)
839 865
@@ -873,33 +899,7 @@ def main():
873 899
874 with torch.inference_mode(): 900 with torch.inference_mode():
875 for step, batch in enumerate(val_dataloader): 901 for step, batch in enumerate(val_dataloader):
876 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 902 loss, acc, bsz = loop(batch)
877 latents = latents * 0.18215
878
879 noise = torch.randn_like(latents)
880 bsz = latents.shape[0]
881 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
882 (bsz,), device=latents.device)
883 timesteps = timesteps.long()
884
885 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
886
887 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
888 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
889
890 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
891
892 # Get the target for loss depending on the prediction type
893 if noise_scheduler.config.prediction_type == "epsilon":
894 target = noise
895 elif noise_scheduler.config.prediction_type == "v_prediction":
896 target = noise_scheduler.get_velocity(latents, noise, timesteps)
897 else:
898 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
899
900 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
901
902 acc = (model_pred == latents).float().mean()
903 903
904 avg_loss_val.update(loss.detach_(), bsz) 904 avg_loss_val.update(loss.detach_(), bsz)
905 avg_acc_val.update(acc.detach_(), bsz) 905 avg_acc_val.update(acc.detach_(), bsz)
diff --git a/training/lr.py b/training/lr.py
new file mode 100644
index 0000000..dd37baa
--- /dev/null
+++ b/training/lr.py
@@ -0,0 +1,115 @@
1import numpy as np
2from torch.optim.lr_scheduler import LambdaLR
3from tqdm.auto import tqdm
4import matplotlib.pyplot as plt
5
6from training.util import AverageMeter
7
8
9class LRFinder():
10 def __init__(self, accelerator, model, optimizer, train_dataloader, loss_fn):
11 self.accelerator = accelerator
12 self.model = model
13 self.optimizer = optimizer
14 self.train_dataloader = train_dataloader
15 self.loss_fn = loss_fn
16
17 def run(self, num_epochs=100, num_steps=1, smooth_f=0.05, diverge_th=5):
18 best_loss = None
19 lrs = []
20 losses = []
21
22 lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs)
23
24 progress_bar = tqdm(
25 range(num_epochs * num_steps),
26 disable=not self.accelerator.is_local_main_process,
27 dynamic_ncols=True
28 )
29 progress_bar.set_description("Epoch X / Y")
30
31 for epoch in range(num_epochs):
32 progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
33
34 avg_loss = AverageMeter()
35
36 for step, batch in enumerate(self.train_dataloader):
37 with self.accelerator.accumulate(self.model):
38 loss, acc, bsz = self.loss_fn(batch)
39
40 self.accelerator.backward(loss)
41
42 self.optimizer.step()
43 self.optimizer.zero_grad(set_to_none=True)
44
45 avg_loss.update(loss.detach_(), bsz)
46
47 if step >= num_steps:
48 break
49
50 if self.accelerator.sync_gradients:
51 progress_bar.update(1)
52
53 lr_scheduler.step()
54
55 loss = avg_loss.avg.item()
56 if epoch == 0:
57 best_loss = loss
58 else:
59 if smooth_f > 0:
60 loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
61 if loss < best_loss:
62 best_loss = loss
63
64 lr = lr_scheduler.get_last_lr()[0]
65
66 lrs.append(lr)
67 losses.append(loss)
68
69 progress_bar.set_postfix({
70 "loss": loss,
71 "best": best_loss,
72 "lr": lr,
73 })
74
75 if loss > diverge_th * best_loss:
76 print("Stopping early, the loss has diverged")
77 break
78
79 fig, ax = plt.subplots()
80 ax.plot(lrs, losses)
81
82 print("LR suggestion: steepest gradient")
83 min_grad_idx = None
84 try:
85 min_grad_idx = (np.gradient(np.array(losses))).argmin()
86 except ValueError:
87 print(
88 "Failed to compute the gradients, there might not be enough points."
89 )
90 if min_grad_idx is not None:
91 print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
92 ax.scatter(
93 lrs[min_grad_idx],
94 losses[min_grad_idx],
95 s=75,
96 marker="o",
97 color="red",
98 zorder=3,
99 label="steepest gradient",
100 )
101 ax.legend()
102
103 ax.set_xscale("log")
104 ax.set_xlabel("Learning rate")
105 ax.set_ylabel("Loss")
106
107 if fig is not None:
108 plt.show()
109
110
111def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1):
112 def lr_lambda(current_epoch: int):
113 return (current_epoch / num_epochs) ** 5
114
115 return LambdaLR(optimizer, lr_lambda, last_epoch)