diff options
author | Volpeon <git@volpeon.ink> | 2023-01-01 11:36:00 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-01 11:36:00 +0100 |
commit | b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5 (patch) | |
tree | 24fd6d9f3a92ce9f5cccd5cdd914edfee665af71 /train_dreambooth.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.tar.gz textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.tar.bz2 textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.zip |
Fixed accuracy calc, other improvements
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 30 |
1 file changed, 29 insertions, 1 deletion
diff --git a/train_dreambooth.py b/train_dreambooth.py index 8fd78f1..1ebcfe3 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -232,6 +232,30 @@ def parse_args(): | |||
232 | help="Number of restart cycles in the lr scheduler (if supported)." | 232 | help="Number of restart cycles in the lr scheduler (if supported)." |
233 | ) | 233 | ) |
234 | parser.add_argument( | 234 | parser.add_argument( |
235 | "--lr_warmup_func", | ||
236 | type=str, | ||
237 | default="cos", | ||
238 | help='Choose between ["linear", "cos"]' | ||
239 | ) | ||
240 | parser.add_argument( | ||
241 | "--lr_warmup_exp", | ||
242 | type=int, | ||
243 | default=1, | ||
244 | help='If lr_warmup_func is "cos", exponent to modify the function' | ||
245 | ) | ||
246 | parser.add_argument( | ||
247 | "--lr_annealing_func", | ||
248 | type=str, | ||
249 | default="cos", | ||
250 | help='Choose between ["linear", "half_cos", "cos"]' | ||
251 | ) | ||
252 | parser.add_argument( | ||
253 | "--lr_annealing_exp", | ||
254 | type=int, | ||
255 | default=3, | ||
256 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | ||
257 | ) | ||
258 | parser.add_argument( | ||
235 | "--use_ema", | 259 | "--use_ema", |
236 | action="store_true", | 260 | action="store_true", |
237 | default=True, | 261 | default=True, |
@@ -760,6 +784,10 @@ def main(): | |||
760 | lr_scheduler = get_one_cycle_schedule( | 784 | lr_scheduler = get_one_cycle_schedule( |
761 | optimizer=optimizer, | 785 | optimizer=optimizer, |
762 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 786 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
787 | warmup=args.lr_warmup_func, | ||
788 | annealing=args.lr_annealing_func, | ||
789 | warmup_exp=args.lr_warmup_exp, | ||
790 | annealing_exp=args.lr_annealing_exp, | ||
763 | ) | 791 | ) |
764 | elif args.lr_scheduler == "cosine_with_restarts": | 792 | elif args.lr_scheduler == "cosine_with_restarts": |
765 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 793 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
@@ -913,7 +941,7 @@ def main(): | |||
913 | else: | 941 | else: |
914 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 942 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
915 | 943 | ||
916 | acc = (model_pred == latents).float().mean() | 944 | acc = (model_pred == target).float().mean() |
917 | 945 | ||
918 | return loss, acc, bsz | 946 | return loss, acc, bsz |
919 | 947 | ||