path: root/training/optimization.py
author    Volpeon <git@volpeon.ink>  2023-01-20 14:26:17 +0100
committer Volpeon <git@volpeon.ink>  2023-01-20 14:26:17 +0100
commit    3575d041f1507811b577fd2c653171fb51c0a386 (patch)
tree      702f9f1ae4eafc6f8ea06560c4de6bbe1c2acecb /training/optimization.py
parent    Move Accelerator preparation into strategy (diff)
Restored LR finder
Diffstat (limited to 'training/optimization.py')
-rw-r--r--  training/optimization.py  19
1 file changed, 19 insertions(+), 0 deletions(-)
diff --git a/training/optimization.py b/training/optimization.py
index 6dee4bc..6c9a35d 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -87,6 +87,15 @@ def get_one_cycle_schedule(
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 
+def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1):
+    def lr_lambda(base_lr: float, current_step: int):
+        return (end_lr / base_lr) ** (current_step / num_training_steps)
+
+    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
+
+    return LambdaLR(optimizer, lr_lambdas, last_epoch)
+
+
 def get_scheduler(
     id: str,
     optimizer: torch.optim.Optimizer,
@@ -97,6 +106,7 @@ def get_scheduler(
     annealing_func: Literal["cos", "half_cos", "linear"] = "cos",
     warmup_exp: int = 1,
     annealing_exp: int = 1,
+    end_lr: float = 1e3,
     cycles: int = 1,
     train_epochs: int = 100,
     warmup_epochs: int = 10,
@@ -117,6 +127,15 @@ def get_scheduler(
             annealing_exp=annealing_exp,
             min_lr=min_lr,
         )
+    elif id == "exponential_growth":
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+
+        lr_scheduler = get_exponential_growing_schedule(
+            optimizer=optimizer,
+            end_lr=end_lr,
+            num_training_steps=num_training_steps,
+        )
     elif id == "cosine_with_restarts":
         if cycles is None:
             cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
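
For reference, a minimal sketch of how the restored exponential-growth schedule drives an LR-finder sweep. The schedule body is copied from the diff above; the toy linear model, the SGD optimizer, and the 1e-6 base LR / 1e3 end_lr / 100-step values are illustrative assumptions, not part of the commit (in the actual code path the scheduler is obtained via get_scheduler(id="exponential_growth", ...)).

from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1):
    # As in the diff: each param group ramps multiplicatively from its base LR toward end_lr.
    def lr_lambda(base_lr: float, current_step: int):
        return (end_lr / base_lr) ** (current_step / num_training_steps)

    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]

    return LambdaLR(optimizer, lr_lambdas, last_epoch)


# Assumed toy setup, purely to give the scheduler something to step.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)  # sweep starts at the base LR
scheduler = get_exponential_growing_schedule(optimizer, end_lr=1e3, num_training_steps=100)

for step in range(100):
    optimizer.step()   # a real LR finder would run a training step and record the loss here
    scheduler.step()
    if step % 25 == 0:
        print(step, scheduler.get_last_lr())  # LR grows geometrically from 1e-6 toward 1e3

At step 0, LambdaLR multiplies the base LR by (end_lr / base_lr) ** 0 = 1, and at step num_training_steps the factor reaches end_lr / base_lr, so the learning rate sweeps the full range on a log scale, which is the sweep an LR finder plots loss against.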