diff options
author | Volpeon <git@volpeon.ink> | 2023-01-16 15:52:43 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-16 15:52:43 +0100 |
commit | 6c8cffe28baeafac77d047ff3f8ded9418033e2f (patch) | |
tree | 807c527deb1b15ef795f5cd8a7682151c69a037e /training/optimization.py | |
parent | Pad dataset if len(items) < batch_size (diff) | |
download | textual-inversion-diff-6c8cffe28baeafac77d047ff3f8ded9418033e2f.tar.gz textual-inversion-diff-6c8cffe28baeafac77d047ff3f8ded9418033e2f.tar.bz2 textual-inversion-diff-6c8cffe28baeafac77d047ff3f8ded9418033e2f.zip |
More training adjustments
Diffstat (limited to 'training/optimization.py')
-rw-r--r-- | training/optimization.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/training/optimization.py b/training/optimization.py index 5db7794..6dee4bc 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
@@ -49,8 +49,8 @@ def get_one_cycle_schedule( | |||
49 | annealing: Literal["cos", "half_cos", "linear"] = "cos", | 49 | annealing: Literal["cos", "half_cos", "linear"] = "cos", |
50 | warmup_exp: int = 1, | 50 | warmup_exp: int = 1, |
51 | annealing_exp: int = 1, | 51 | annealing_exp: int = 1, |
52 | min_lr: int = 0.04, | 52 | min_lr: float = 0.04, |
53 | mid_point: int = 0.3, | 53 | mid_point: float = 0.3, |
54 | last_epoch: int = -1 | 54 | last_epoch: int = -1 |
55 | ): | 55 | ): |
56 | if warmup == "linear": | 56 | if warmup == "linear": |
@@ -91,10 +91,10 @@ def get_scheduler( | |||
91 | id: str, | 91 | id: str, |
92 | optimizer: torch.optim.Optimizer, | 92 | optimizer: torch.optim.Optimizer, |
93 | num_training_steps_per_epoch: int, | 93 | num_training_steps_per_epoch: int, |
94 | gradient_accumulation_steps: int, | 94 | gradient_accumulation_steps: int = 1, |
95 | min_lr: float = 0.04, | 95 | min_lr: float = 0.04, |
96 | warmup_func: str = "cos", | 96 | warmup_func: Literal["cos", "linear"] = "cos", |
97 | annealing_func: str = "cos", | 97 | annealing_func: Literal["cos", "half_cos", "linear"] = "cos", |
98 | warmup_exp: int = 1, | 98 | warmup_exp: int = 1, |
99 | annealing_exp: int = 1, | 99 | annealing_exp: int = 1, |
100 | cycles: int = 1, | 100 | cycles: int = 1, |