path: root/train_dreambooth.py
author    Volpeon <git@volpeon.ink>  2023-01-12 13:50:22 +0100
committer Volpeon <git@volpeon.ink>  2023-01-12 13:50:22 +0100
commit    f963d4cba5c4c6575d77be80621a40b615603ca3 (patch)
tree      d60ecb8c99534a12cef8070b0cf5a77eecc1c8d1 /train_dreambooth.py
parent    Fixed TI decay (diff)
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  56
1 file changed, 42 insertions(+), 14 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 73d9935..ebcf802 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -87,6 +87,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
+    parser.add_argument(
         "--exclude_collections",
         type=str,
         nargs='*',
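
The new --num_vectors option is declared with type=int and nargs='*', so after parsing it can be None (flag absent), an empty list, or a list of integers; the validation added in the next hunk has to handle all three. A standalone sketch (not the project's parser) of how such a flag behaves:

    # Standalone sketch: behaviour of an argparse flag declared like
    # --num_vectors above (type=int, nargs='*'). Not the project's parser.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_vectors",
        type=int,
        nargs='*',
        help="Number of vectors per embedding."
    )

    print(parser.parse_args([]).num_vectors)                          # None (flag absent)
    print(parser.parse_args(["--num_vectors"]).num_vectors)           # []
    print(parser.parse_args(["--num_vectors", "3"]).num_vectors)      # [3]
    print(parser.parse_args(["--num_vectors", "2", "4"]).num_vectors) # [2, 4]
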
@@ -444,17 +450,29 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
     if isinstance(args.placeholder_token, str):
         args.placeholder_token = [args.placeholder_token]
 
     if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+
+    if len(args.initializer_token) == 0:
+        raise ValueError("You must specify --initializer_token")
 
     if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if args.num_vectors is None:
+        args.num_vectors = 1
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+
+    if len(args.placeholder_token) != len(args.num_vectors):
+        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
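
The reworked validation above reorders the normalization: placeholder tokens are handled first, a single initializer token string is repeated to match them, and --num_vectors is broadcast the same way so every placeholder ends up paired with one initializer token and one vector count. As a hypothetical illustration (not the commit's code, and with lengths taken only after both token arguments have been normalized to lists), the same broadcasting rules could be written as a standalone helper:

    # Hypothetical helper (not the commit's code): the same broadcasting rules,
    # written standalone.
    def normalize_token_args(placeholder_token, initializer_token, num_vectors):
        if isinstance(placeholder_token, str):
            placeholder_token = [placeholder_token]
        if isinstance(initializer_token, str):
            initializer_token = [initializer_token]

        if len(initializer_token) == 0:
            raise ValueError("You must specify --initializer_token")

        if len(placeholder_token) == 0:
            # Auto-name placeholders "<*0>", "<*1>", ..., one per initializer token.
            placeholder_token = [f"<*{i}>" for i in range(len(initializer_token))]

        if len(initializer_token) == 1:
            # A single initializer token is reused for every placeholder.
            initializer_token = initializer_token * len(placeholder_token)

        if len(placeholder_token) != len(initializer_token):
            raise ValueError("--placeholder_token and --initializer_token must have the same number of items")

        if num_vectors is None:
            num_vectors = 1
        if isinstance(num_vectors, int):
            num_vectors = [num_vectors] * len(placeholder_token)

        if len(placeholder_token) != len(num_vectors):
            raise ValueError("--placeholder_token and --num_vectors must have the same number of items")

        return placeholder_token, initializer_token, num_vectors

For example, normalize_token_args(["<obj>"], "sks", 3) returns (["<obj>"], ["sks"], [3]), and normalize_token_args([], ["a", "b"], None) returns (["<*0>", "<*1>"], ["a", "b"], [1, 1]).
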
@@ -882,6 +900,18 @@ def main():
     finally:
         pass
 
+    def on_before_optimize():
+        if accelerator.sync_gradients:
+            params_to_clip = [unet.parameters()]
+            if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
+                params_to_clip.append(text_encoder.parameters())
+            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), args.max_grad_norm)
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        if not args.train_text_encoder:
+            text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+
     loop = partial(
         loss_step,
         vae,
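
The two callbacks added above pull per-step logic out of the training loop: on_before_optimize clips gradients with accelerator.clip_grad_norm_ (including the text encoder's parameters only while it is still being trained), and on_after_optimize renormalizes the learned embeddings with a strength tied to the current learning rate when the text encoder itself is frozen. A minimal, self-contained sketch of how such hooks sit around an Accelerate optimizer step; the model, data, and hyperparameters are placeholders, not the project's:

    # Illustrative sketch of before/after-optimize hooks around an Accelerate
    # training step. Model, data, and hyperparameters are placeholders.
    import torch
    from accelerate import Accelerator
    from torch.utils.data import DataLoader, TensorDataset

    accelerator = Accelerator(gradient_accumulation_steps=4)

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 8))
    loader = DataLoader(dataset, batch_size=8)
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

    def on_before_optimize():
        # Clip only on steps where gradients are actually synchronized.
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)

    @torch.no_grad()
    def on_after_optimize(lr: float):
        # Post-step bookkeeping, e.g. renormalizing embeddings scaled by the LR.
        pass

    for x, y in loader:
        with accelerator.accumulate(model):
            loss = torch.nn.functional.mse_loss(model(x), y)
            accelerator.backward(loss)
            on_before_optimize()
            optimizer.step()
            optimizer.zero_grad()
        if accelerator.sync_gradients:
            on_after_optimize(optimizer.param_groups[0]["lr"])
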
@@ -915,10 +945,12 @@ def main():
            loop,
            on_train=tokenizer.train,
            on_eval=tokenizer.eval,
+            on_before_optimize=on_before_optimize,
+            on_after_optimize=on_after_optimize,
         )
-        lr_finder.run(end_lr=1e2)
+        lr_finder.run(num_epochs=100, end_lr=1e3)
 
-        plt.savefig(basepath.joinpath("lr.png"))
+        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
 
         quit()
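
The learning-rate finder now also receives the two hooks, sweeps further (num_epochs=100, end_lr raised from 1e2 to 1e3), and saves the plot at dpi=300. The project's LRFinder class is not shown in this diff; as a generic sketch of the LR range test it performs, the learning rate grows exponentially toward end_lr while the loss is recorded and plotted on a log axis:

    # Generic LR range test (illustrative; not the project's LRFinder).
    import matplotlib.pyplot as plt

    def lr_range_test(model, optimizer, loss_fn, batches, num_steps=100, end_lr=1e3):
        start_lr = optimizer.param_groups[0]["lr"]
        gamma = (end_lr / start_lr) ** (1.0 / num_steps)  # per-step multiplier
        lrs, losses = [], []
        for _, (x, y) in zip(range(num_steps), batches):
            loss = loss_fn(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lrs.append(optimizer.param_groups[0]["lr"])
            losses.append(loss.item())
            for group in optimizer.param_groups:
                group["lr"] *= gamma  # exponential ramp toward end_lr
        plt.plot(lrs, losses)
        plt.xscale("log")
        plt.xlabel("learning rate")
        plt.ylabel("loss")
        plt.savefig("lr.png", dpi=300)
        plt.close()
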
@@ -999,13 +1031,7 @@ def main():
 
                 accelerator.backward(loss)
 
-                if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(unet.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                        else unet.parameters()
-                    )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                on_before_optimize()
 
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
@@ -1019,6 +1045,8 @@ def main():
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
+                    on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
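
At the end of each synchronized step, on_after_optimize receives the scheduler's last learning rate and, when the text encoder is frozen, calls the custom embeddings.normalize() with a decay of min(1.0, 100 * lr), so the pull on the embeddings weakens as the learning rate drops. That normalize() method is not part of this diff; a hypothetical reading of such a decayed normalization, blending each embedding row toward its unit-norm version in place:

    # Hypothetical decayed normalization (the project's embeddings.normalize()
    # is not shown here): nudge each embedding row toward unit norm by `decay`.
    import torch

    @torch.no_grad()
    def normalize_rows(weight: torch.Tensor, decay: float) -> None:
        norms = weight.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        weight.lerp_(weight / norms, decay)  # in-place partial renormalization

    embedding = torch.nn.Embedding(8, 16)
    lr = 1e-4
    normalize_rows(embedding.weight, decay=min(1.0, 100 * lr))  # decay == 0.01
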