summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- train_dreambooth.py 56
-rw-r--r-- train_ti.py 30
-rw-r--r-- training/lr.py 11
3 files changed, 63 insertions, 34 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 73d9935..ebcf802 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -87,6 +87,12 @@ def parse_args():
87 help="A token to use as initializer word." 87 help="A token to use as initializer word."
88 ) 88 )
89 parser.add_argument( 89 parser.add_argument(
90 "--num_vectors",
91 type=int,
92 nargs='*',
93 help="Number of vectors per embedding."
94 )
95 parser.add_argument(
90 "--exclude_collections", 96 "--exclude_collections",
91 type=str, 97 type=str,
92 nargs='*', 98 nargs='*',
@@ -444,17 +450,29 @@ def parse_args():
444 if args.project is None: 450 if args.project is None:
445 raise ValueError("You must specify --project") 451 raise ValueError("You must specify --project")
446 452
447 if isinstance(args.initializer_token, str):
448 args.initializer_token = [args.initializer_token]
449
450 if isinstance(args.placeholder_token, str): 453 if isinstance(args.placeholder_token, str):
451 args.placeholder_token = [args.placeholder_token] 454 args.placeholder_token = [args.placeholder_token]
452 455
453 if len(args.placeholder_token) == 0: 456 if len(args.placeholder_token) == 0:
454 args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))] 457 args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
458
459 if isinstance(args.initializer_token, str):
460 args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
461
462 if len(args.initializer_token) == 0:
463 raise ValueError("You must specify --initializer_token")
455 464
456 if len(args.placeholder_token) != len(args.initializer_token): 465 if len(args.placeholder_token) != len(args.initializer_token):
457 raise ValueError("Number of items in --placeholder_token and --initializer_token must match") 466 raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
467
468 if args.num_vectors is None:
469 args.num_vectors = 1
470
471 if isinstance(args.num_vectors, int):
472 args.num_vectors = [args.num_vectors] * len(args.initializer_token)
473
474 if len(args.placeholder_token) != len(args.num_vectors):
475 raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
458 476
459 if isinstance(args.collection, str): 477 if isinstance(args.collection, str):
460 args.collection = [args.collection] 478 args.collection = [args.collection]
@@ -882,6 +900,18 @@ def main():
882 finally: 900 finally:
883 pass 901 pass
884 902
903 def on_before_optimize():
904 if accelerator.sync_gradients:
905 params_to_clip = [unet.parameters()]
906 if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
907 params_to_clip.append(text_encoder.parameters())
908 accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), args.max_grad_norm)
909
910 @torch.no_grad()
911 def on_after_optimize(lr: float):
912 if not args.train_text_encoder:
913 text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
914
885 loop = partial( 915 loop = partial(
886 loss_step, 916 loss_step,
887 vae, 917 vae,
@@ -915,10 +945,12 @@ def main():
915 loop, 945 loop,
916 on_train=tokenizer.train, 946 on_train=tokenizer.train,
917 on_eval=tokenizer.eval, 947 on_eval=tokenizer.eval,
948 on_before_optimize=on_before_optimize,
949 on_after_optimize=on_after_optimize,
918 ) 950 )
919 lr_finder.run(end_lr=1e2) 951 lr_finder.run(num_epochs=100, end_lr=1e3)
920 952
921 plt.savefig(basepath.joinpath("lr.png")) 953 plt.savefig(basepath.joinpath("lr.png"), dpi=300)
922 plt.close() 954 plt.close()
923 955
924 quit() 956 quit()
@@ -999,13 +1031,7 @@ def main():
999 1031
1000 accelerator.backward(loss) 1032 accelerator.backward(loss)
1001 1033
1002 if accelerator.sync_gradients: 1034 on_before_optimize()
1003 params_to_clip = (
1004 itertools.chain(unet.parameters(), text_encoder.parameters())
1005 if args.train_text_encoder and epoch < args.train_text_encoder_epochs
1006 else unet.parameters()
1007 )
1008 accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1009 1035
1010 optimizer.step() 1036 optimizer.step()
1011 if not accelerator.optimizer_step_was_skipped: 1037 if not accelerator.optimizer_step_was_skipped:
@@ -1019,6 +1045,8 @@ def main():
1019 1045
1020 # Checks if the accelerator has performed an optimization step behind the scenes 1046 # Checks if the accelerator has performed an optimization step behind the scenes
1021 if accelerator.sync_gradients: 1047 if accelerator.sync_gradients:
1048 on_after_optimize(lr_scheduler.get_last_lr()[0])
1049
1022 local_progress_bar.update(1) 1050 local_progress_bar.update(1)
1023 global_progress_bar.update(1) 1051 global_progress_bar.update(1)
1024 1052
diff --git a/train_ti.py b/train_ti.py
index 890c465..9ec5cfb 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -452,27 +452,27 @@ def parse_args():
452 if args.project is None: 452 if args.project is None:
453 raise ValueError("You must specify --project") 453 raise ValueError("You must specify --project")
454 454
455 if isinstance(args.initializer_token, str):
456 args.initializer_token = [args.initializer_token]
457
458 if len(args.initializer_token) == 0:
459 raise ValueError("You must specify --initializer_token")
460
461 if isinstance(args.placeholder_token, str): 455 if isinstance(args.placeholder_token, str):
462 args.placeholder_token = [args.placeholder_token] 456 args.placeholder_token = [args.placeholder_token]
463 457
464 if len(args.placeholder_token) == 0: 458 if len(args.placeholder_token) == 0:
465 args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)] 459 args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
466 460
461 if isinstance(args.initializer_token, str):
462 args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
463
464 if len(args.initializer_token) == 0:
465 raise ValueError("You must specify --initializer_token")
466
467 if len(args.placeholder_token) != len(args.initializer_token):
468 raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
469
467 if args.num_vectors is None: 470 if args.num_vectors is None:
468 args.num_vectors = 1 471 args.num_vectors = 1
469 472
470 if isinstance(args.num_vectors, int): 473 if isinstance(args.num_vectors, int):
471 args.num_vectors = [args.num_vectors] * len(args.initializer_token) 474 args.num_vectors = [args.num_vectors] * len(args.initializer_token)
472 475
473 if len(args.placeholder_token) != len(args.initializer_token):
474 raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
475
476 if len(args.placeholder_token) != len(args.num_vectors): 476 if len(args.placeholder_token) != len(args.num_vectors):
477 raise ValueError("--placeholder_token and --num_vectors must have the same number of items") 477 raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
478 478
@@ -867,7 +867,7 @@ def main():
867 pass 867 pass
868 868
869 @torch.no_grad() 869 @torch.no_grad()
870 def on_clip(lr): 870 def on_after_optimize(lr: float):
871 text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr)) 871 text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
872 872
873 loop = partial( 873 loop = partial(
@@ -904,7 +904,7 @@ def main():
904 loop, 904 loop,
905 on_train=on_train, 905 on_train=on_train,
906 on_eval=on_eval, 906 on_eval=on_eval,
907 on_clip=on_clip, 907 on_after_optimize=on_after_optimize,
908 ) 908 )
909 lr_finder.run(num_epochs=100, end_lr=1e3) 909 lr_finder.run(num_epochs=100, end_lr=1e3)
910 910
@@ -985,12 +985,8 @@ def main():
985 985
986 accelerator.backward(loss) 986 accelerator.backward(loss)
987 987
988 if accelerator.sync_gradients:
989 on_clip(lr_scheduler.get_last_lr()[0])
990
991 optimizer.step() 988 optimizer.step()
992 if not accelerator.optimizer_step_was_skipped: 989 lr_scheduler.step()
993 lr_scheduler.step()
994 optimizer.zero_grad(set_to_none=True) 990 optimizer.zero_grad(set_to_none=True)
995 991
996 avg_loss.update(loss.detach_(), bsz) 992 avg_loss.update(loss.detach_(), bsz)
@@ -998,6 +994,8 @@ def main():
998 994
999 # Checks if the accelerator has performed an optimization step behind the scenes 995 # Checks if the accelerator has performed an optimization step behind the scenes
1000 if accelerator.sync_gradients: 996 if accelerator.sync_gradients:
997 on_after_optimize(lr_scheduler.get_last_lr()[0])
998
1001 if args.use_ema: 999 if args.use_ema:
1002 ema_embeddings.step( 1000 ema_embeddings.step(
1003 text_encoder.text_model.embeddings.temp_token_embedding.parameters()) 1001 text_encoder.text_model.embeddings.temp_token_embedding.parameters())
diff --git a/training/lr.py b/training/lr.py
index 01f7f5e..84e30a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -26,7 +26,8 @@ class LRFinder():
26 val_dataloader, 26 val_dataloader,
27 loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], 27 loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
28 on_train: Callable[[], _GeneratorContextManager] = nullcontext, 28 on_train: Callable[[], _GeneratorContextManager] = nullcontext,
29 on_clip: Callable[[float], None] = noop, 29 on_before_optimize: Callable[[], None] = noop,
30 on_after_optimize: Callable[[float], None] = noop,
30 on_eval: Callable[[], _GeneratorContextManager] = nullcontext 31 on_eval: Callable[[], _GeneratorContextManager] = nullcontext
31 ): 32 ):
32 self.accelerator = accelerator 33 self.accelerator = accelerator
@@ -36,7 +37,8 @@ class LRFinder():
36 self.val_dataloader = val_dataloader 37 self.val_dataloader = val_dataloader
37 self.loss_fn = loss_fn 38 self.loss_fn = loss_fn
38 self.on_train = on_train 39 self.on_train = on_train
39 self.on_clip = on_clip 40 self.on_before_optimize = on_before_optimize
41 self.on_after_optimize = on_after_optimize
40 self.on_eval = on_eval 42 self.on_eval = on_eval
41 43
42 # self.model_state = copy.deepcopy(model.state_dict()) 44 # self.model_state = copy.deepcopy(model.state_dict())
@@ -94,14 +96,15 @@ class LRFinder():
94 96
95 self.accelerator.backward(loss) 97 self.accelerator.backward(loss)
96 98
97 if self.accelerator.sync_gradients: 99 self.on_before_optimize()
98 self.on_clip(lr_scheduler.get_last_lr()[0])
99 100
100 self.optimizer.step() 101 self.optimizer.step()
101 lr_scheduler.step() 102 lr_scheduler.step()
102 self.optimizer.zero_grad(set_to_none=True) 103 self.optimizer.zero_grad(set_to_none=True)
103 104
104 if self.accelerator.sync_gradients: 105 if self.accelerator.sync_gradients:
106 self.on_after_optimize(lr_scheduler.get_last_lr()[0])
107
105 progress_bar.update(1) 108 progress_bar.update(1)
106 109
107 self.model.eval() 110 self.model.eval()