author     Volpeon <git@volpeon.ink>    2023-01-12 13:50:22 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-12 13:50:22 +0100
commit     f963d4cba5c4c6575d77be80621a40b615603ca3 (patch)
tree       d60ecb8c99534a12cef8070b0cf5a77eecc1c8d1 /train_ti.py
parent     Fixed TI decay (diff)
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  30
1 file changed, 14 insertions, 16 deletions
diff --git a/train_ti.py b/train_ti.py
index 890c465..9ec5cfb 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -452,27 +452,27 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")

-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
-    if len(args.initializer_token) == 0:
-        raise ValueError("You must specify --initializer_token")
-
     if isinstance(args.placeholder_token, str):
         args.placeholder_token = [args.placeholder_token]

     if len(args.placeholder_token) == 0:
         args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]

+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+
+    if len(args.initializer_token) == 0:
+        raise ValueError("You must specify --initializer_token")
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
     if args.num_vectors is None:
         args.num_vectors = 1

     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.initializer_token)

-    if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
-
     if len(args.placeholder_token) != len(args.num_vectors):
         raise ValueError("--placeholder_token and --num_vectors must have the same number of items")

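Note on the reordered checks above: the placeholder tokens are now normalized first, so missing placeholders can be auto-named as <*0>, <*1>, ... from the initializers, and a single --initializer_token string is broadcast to every placeholder before the length checks run. A standalone sketch of that flow (the normalize_token_args wrapper, the argparse.Namespace wiring, and the explicit len() inside range() are illustrative assumptions; the patch itself passes the list straight to range()):

    from argparse import Namespace

    def normalize_token_args(args: Namespace) -> Namespace:
        # Placeholders first, so they can be derived from the initializer tokens.
        if isinstance(args.placeholder_token, str):
            args.placeholder_token = [args.placeholder_token]
        if len(args.placeholder_token) == 0:
            args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]

        # A single initializer token is broadcast across all placeholders.
        if isinstance(args.initializer_token, str):
            args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
        if len(args.initializer_token) == 0:
            raise ValueError("You must specify --initializer_token")
        if len(args.placeholder_token) != len(args.initializer_token):
            raise ValueError("--placeholder_token and --initializer_token must have the same number of items")

        # num_vectors broadcasts the same way.
        if args.num_vectors is None:
            args.num_vectors = 1
        if isinstance(args.num_vectors, int):
            args.num_vectors = [args.num_vectors] * len(args.initializer_token)
        if len(args.placeholder_token) != len(args.num_vectors):
            raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
        return args

    # Example: one initializer string and one vector count expand to match two placeholders.
    args = normalize_token_args(Namespace(placeholder_token=["<a>", "<b>"], initializer_token="person", num_vectors=2))
    print(args.initializer_token)  # ['person', 'person']
    print(args.num_vectors)        # [2, 2]
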
@@ -867,7 +867,7 @@ def main():
         pass

     @torch.no_grad()
-    def on_clip(lr):
+    def on_after_optimize(lr: float):
         text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))

     loop = partial(
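The strength passed to the embedding normalization in on_after_optimize is min(1.0, 100 * lr): it grows linearly with the learning rate and saturates at 1.0 once lr reaches 1e-2. The normalize() implementation itself lives elsewhere in the repository and is not part of this diff; a quick check of the clamp only:

    # Post-step embedding normalization strength as a function of the learning rate.
    for lr in (1e-5, 1e-4, 1e-3, 1e-2, 1e-1):
        print(f"lr={lr:g} -> strength={min(1.0, 100 * lr):g}")
    # lr=1e-05 -> strength=0.001
    # lr=0.0001 -> strength=0.01
    # lr=0.001 -> strength=0.1
    # lr=0.01 -> strength=1
    # lr=0.1 -> strength=1
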
@@ -904,7 +904,7 @@ def main():
             loop,
             on_train=on_train,
             on_eval=on_eval,
-            on_clip=on_clip,
+            on_after_optimize=on_after_optimize,
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)

@@ -985,12 +985,8 @@ def main():

                 accelerator.backward(loss)

-                if accelerator.sync_gradients:
-                    on_clip(lr_scheduler.get_last_lr()[0])
-
                 optimizer.step()
-                if not accelerator.optimizer_step_was_skipped:
-                    lr_scheduler.step()
+                lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

                 avg_loss.update(loss.detach_(), bsz)
@@ -998,6 +994,8 @@ def main():

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                 if args.use_ema:
                     ema_embeddings.step(
                         text_encoder.text_model.embeddings.temp_token_embedding.parameters())
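Taken together, the last two hunks move the hook from before the optimizer step (the old on_clip call under sync_gradients) to after optimizer.step() and lr_scheduler.step(), still only on steps where accelerate actually applies gradients. A minimal sketch of that ordering outside of accelerate, with a toy parameter and a hypothetical training_step helper standing in for the real loop:

    from typing import Callable, Optional

    import torch

    def training_step(
        loss: torch.Tensor,
        optimizer: torch.optim.Optimizer,
        lr_scheduler,
        on_after_optimize: Optional[Callable[[float], None]] = None,
    ) -> None:
        # Mirror the patched ordering: update first, then invoke the hook with
        # the learning rate that was just applied.
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=True)
        if on_after_optimize is not None:
            on_after_optimize(lr_scheduler.get_last_lr()[0])

    # Toy usage; in train_ti.py the callback normalizes the new token embeddings.
    param = torch.nn.Parameter(torch.randn(4))
    opt = torch.optim.SGD([param], lr=1e-2)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: 1.0)
    training_step((param ** 2).sum(), opt, sched,
                  on_after_optimize=lambda lr: print(f"post-step lr: {lr}"))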