author     Volpeon <git@volpeon.ink>  2023-01-01 11:36:00 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-01 11:36:00 +0100
commit     b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5 (patch)
tree       24fd6d9f3a92ce9f5cccd5cdd914edfee665af71 /train_dreambooth.py
parent     Fix (diff)
download   textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.tar.gz
           textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.tar.bz2
           textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.zip
Fixed accuracy calc, other improvements
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  30
1 file changed, 29 insertions(+), 1 deletion(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 8fd78f1..1ebcfe3 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -232,6 +232,30 @@ def parse_args():
232 help="Number of restart cycles in the lr scheduler (if supported)." 232 help="Number of restart cycles in the lr scheduler (if supported)."
233 ) 233 )
234 parser.add_argument( 234 parser.add_argument(
235 "--lr_warmup_func",
236 type=str,
237 default="cos",
238 help='Choose between ["linear", "cos"]'
239 )
240 parser.add_argument(
241 "--lr_warmup_exp",
242 type=int,
243 default=1,
244 help='If lr_warmup_func is "cos", exponent to modify the function'
245 )
246 parser.add_argument(
247 "--lr_annealing_func",
248 type=str,
249 default="cos",
250 help='Choose between ["linear", "half_cos", "cos"]'
251 )
252 parser.add_argument(
253 "--lr_annealing_exp",
254 type=int,
255 default=3,
256 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
257 )
258 parser.add_argument(
235 "--use_ema", 259 "--use_ema",
236 action="store_true", 260 action="store_true",
237 default=True, 261 default=True,
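
The four new flags select the shape of the one-cycle schedule's warmup and annealing phases. Below is a minimal sketch of what such shape functions could look like, assuming each maps phase progress in [0, 1] to an lr multiplier in [0, 1]; the names and formulas are illustrative, not the repository's actual get_one_cycle_schedule internals.

import math

def warmup_curve(progress, func="cos", exp=1):
    # Ramp from 0 to 1; per the help strings, exp only applies to "cos".
    # A larger exp keeps the lr low for longer before ramping up.
    if func == "linear":
        return progress
    return (0.5 * (1.0 - math.cos(math.pi * progress))) ** exp

def annealing_curve(progress, func="cos", exp=1):
    # Decay from 1 to 0; a larger exp drops the lr faster early on.
    if func == "linear":
        return 1.0 - progress
    if func == "half_cos":
        # First quarter of a cosine period: a gentler start than full "cos".
        return math.cos(0.5 * math.pi * progress) ** exp
    return (0.5 * (1.0 + math.cos(math.pi * progress))) ** exp
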
@@ -760,6 +784,10 @@ def main():
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            warmup=args.lr_warmup_func,
+            annealing=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
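
A rough sketch of how such curves might be composed into a torch scheduler behind the call above; the 10% warmup length and the composition itself are assumptions, since the real get_one_cycle_schedule lives elsewhere in this repository.

from torch.optim.lr_scheduler import LambdaLR

def one_cycle_lambda(num_training_steps, warmup_steps, warmup_fn, annealing_fn):
    # Hypothetical composition: ramp up for warmup_steps, then anneal to zero.
    def lr_lambda(step):
        if step < warmup_steps:
            return warmup_fn(step / max(1, warmup_steps))
        progress = (step - warmup_steps) / max(1, num_training_steps - warmup_steps)
        return annealing_fn(progress)
    return lr_lambda

# Usage with the curves sketched above, mirroring the defaults in parse_args():
# scheduler = LambdaLR(optimizer, one_cycle_lambda(
#     total_steps, total_steps // 10,
#     lambda p: warmup_curve(p, "cos", 1),
#     lambda p: annealing_curve(p, "cos", 3)))
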
@@ -913,7 +941,7 @@ def main():
         else:
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-        acc = (model_pred == latents).float().mean()
+        acc = (model_pred == target).float().mean()
 
         return loss, acc, bsz
 
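
The one-line change in the last hunk is the accuracy fix named in the commit message: the metric now compares predictions against the same target the MSE loss uses. In the standard diffusers training loop, which this script presumably follows, that target is the sampled noise under epsilon-prediction or the velocity under v-prediction, never the clean latents, so the old comparison against latents could not agree with the loss.

import torch.nn.functional as F

# Sketch of the usual target selection in a diffusers training loop;
# variable names match the diff above, but the branches are assumptions
# based on the standard diffusers examples, not this script's exact code.
if noise_scheduler.config.prediction_type == "epsilon":
    target = noise  # the model is trained to predict the added noise
elif noise_scheduler.config.prediction_type == "v_prediction":
    target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
acc = (model_pred == target).float().mean()  # compare against the regression target
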