diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-01 11:36:00 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-01 11:36:00 +0100 |
| commit | b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5 (patch) | |
| tree | 24fd6d9f3a92ce9f5cccd5cdd914edfee665af71 /train_dreambooth.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.tar.gz textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.tar.bz2 textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.zip | |
Fixed accuracy calc, other improvements
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 30 |
1 file changed, 29 insertions, 1 deletion
diff --git a/train_dreambooth.py b/train_dreambooth.py index 8fd78f1..1ebcfe3 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -232,6 +232,30 @@ def parse_args(): | |||
| 232 | help="Number of restart cycles in the lr scheduler (if supported)." | 232 | help="Number of restart cycles in the lr scheduler (if supported)." |
| 233 | ) | 233 | ) |
| 234 | parser.add_argument( | 234 | parser.add_argument( |
| 235 | "--lr_warmup_func", | ||
| 236 | type=str, | ||
| 237 | default="cos", | ||
| 238 | help='Choose between ["linear", "cos"]' | ||
| 239 | ) | ||
| 240 | parser.add_argument( | ||
| 241 | "--lr_warmup_exp", | ||
| 242 | type=int, | ||
| 243 | default=1, | ||
| 244 | help='If lr_warmup_func is "cos", exponent to modify the function' | ||
| 245 | ) | ||
| 246 | parser.add_argument( | ||
| 247 | "--lr_annealing_func", | ||
| 248 | type=str, | ||
| 249 | default="cos", | ||
| 250 | help='Choose between ["linear", "half_cos", "cos"]' | ||
| 251 | ) | ||
| 252 | parser.add_argument( | ||
| 253 | "--lr_annealing_exp", | ||
| 254 | type=int, | ||
| 255 | default=3, | ||
| 256 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | ||
| 257 | ) | ||
| 258 | parser.add_argument( | ||
| 235 | "--use_ema", | 259 | "--use_ema", |
| 236 | action="store_true", | 260 | action="store_true", |
| 237 | default=True, | 261 | default=True, |
| @@ -760,6 +784,10 @@ def main(): | |||
| 760 | lr_scheduler = get_one_cycle_schedule( | 784 | lr_scheduler = get_one_cycle_schedule( |
| 761 | optimizer=optimizer, | 785 | optimizer=optimizer, |
| 762 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 786 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
| 787 | warmup=args.lr_warmup_func, | ||
| 788 | annealing=args.lr_annealing_func, | ||
| 789 | warmup_exp=args.lr_warmup_exp, | ||
| 790 | annealing_exp=args.lr_annealing_exp, | ||
| 763 | ) | 791 | ) |
| 764 | elif args.lr_scheduler == "cosine_with_restarts": | 792 | elif args.lr_scheduler == "cosine_with_restarts": |
| 765 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 793 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
| @@ -913,7 +941,7 @@ def main(): | |||
| 913 | else: | 941 | else: |
| 914 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 942 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
| 915 | 943 | ||
| 916 | acc = (model_pred == latents).float().mean() | 944 | acc = (model_pred == target).float().mean() |
| 917 | 945 | ||
| 918 | return loss, acc, bsz | 946 | return loss, acc, bsz |
| 919 | 947 | ||
