author     Volpeon <git@volpeon.ink>    2023-01-12 13:50:22 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-12 13:50:22 +0100
commit     f963d4cba5c4c6575d77be80621a40b615603ca3 (patch)
tree       d60ecb8c99534a12cef8070b0cf5a77eecc1c8d1
parent     Fixed TI decay (diff)
download   textual-inversion-diff-f963d4cba5c4c6575d77be80621a40b615603ca3.tar.gz
           textual-inversion-diff-f963d4cba5c4c6575d77be80621a40b615603ca3.tar.bz2
           textual-inversion-diff-f963d4cba5c4c6575d77be80621a40b615603ca3.zip

Update

-rw-r--r--  train_dreambooth.py  56
-rw-r--r--  train_ti.py          30
-rw-r--r--  training/lr.py       11
3 files changed, 63 insertions, 34 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 73d9935..ebcf802 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -87,6 +87,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
+    parser.add_argument(
         "--exclude_collections",
         type=str,
         nargs='*',
@@ -444,17 +450,29 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")

-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
     if isinstance(args.placeholder_token, str):
         args.placeholder_token = [args.placeholder_token]

     if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+
+    if len(args.initializer_token) == 0:
+        raise ValueError("You must specify --initializer_token")

     if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if args.num_vectors is None:
+        args.num_vectors = 1
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+
+    if len(args.placeholder_token) != len(args.num_vectors):
+        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")

     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -882,6 +900,18 @@ def main():
     finally:
         pass

+    def on_before_optimize():
+        if accelerator.sync_gradients:
+            params_to_clip = [unet.parameters()]
+            if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
+                params_to_clip.append(text_encoder.parameters())
+            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), args.max_grad_norm)
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        if not args.train_text_encoder:
+            text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+
     loop = partial(
         loss_step,
         vae,
@@ -915,10 +945,12 @@ def main():
             loop,
             on_train=tokenizer.train,
             on_eval=tokenizer.eval,
+            on_before_optimize=on_before_optimize,
+            on_after_optimize=on_after_optimize,
         )
-        lr_finder.run(end_lr=1e2)
+        lr_finder.run(num_epochs=100, end_lr=1e3)

-        plt.savefig(basepath.joinpath("lr.png"))
+        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()

         quit()
@@ -999,13 +1031,7 @@ def main():

                 accelerator.backward(loss)

-                if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(unet.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                        else unet.parameters()
-                    )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                on_before_optimize()

                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
@@ -1019,6 +1045,8 @@ def main():

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)

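The parse_args() changes above reorder the token checks so that a single --initializer_token can be broadcast across several --placeholder_token entries, and the new --num_vectors flag is defaulted and broadcast the same way. For context, a minimal self-contained sketch of that normalization logic, outside of argparse; normalize_token_args is a hypothetical helper name, not a function in this repository:

# Illustrative sketch only: mirrors the argument normalization added in
# parse_args() above.
def normalize_token_args(placeholder_token, initializer_token, num_vectors):
    if isinstance(placeholder_token, str):
        placeholder_token = [placeholder_token]

    if isinstance(initializer_token, str):
        # A single initializer word is reused for every placeholder token.
        initializer_token = [initializer_token] * len(placeholder_token)

    if len(initializer_token) == 0:
        raise ValueError("You must specify --initializer_token")

    if len(placeholder_token) != len(initializer_token):
        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")

    if num_vectors is None:
        num_vectors = 1

    if isinstance(num_vectors, int):
        # One vector count is reused for every embedding.
        num_vectors = [num_vectors] * len(initializer_token)

    if len(placeholder_token) != len(num_vectors):
        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")

    return placeholder_token, initializer_token, num_vectors


# Two placeholder tokens, one initializer word, default vector count:
print(normalize_token_args(["<A>", "<B>"], "person", None))
# (['<A>', '<B>'], ['person', 'person'], [1, 1])
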
diff --git a/train_ti.py b/train_ti.py
index 890c465..9ec5cfb 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -452,27 +452,27 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")

-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
-    if len(args.initializer_token) == 0:
-        raise ValueError("You must specify --initializer_token")
-
     if isinstance(args.placeholder_token, str):
         args.placeholder_token = [args.placeholder_token]

     if len(args.placeholder_token) == 0:
         args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]

+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+
+    if len(args.initializer_token) == 0:
+        raise ValueError("You must specify --initializer_token")
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
     if args.num_vectors is None:
         args.num_vectors = 1

     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.initializer_token)

-    if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
-
     if len(args.placeholder_token) != len(args.num_vectors):
         raise ValueError("--placeholder_token and --num_vectors must have the same number of items")

@@ -867,7 +867,7 @@ def main():
         pass

     @torch.no_grad()
-    def on_clip(lr):
+    def on_after_optimize(lr: float):
         text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))

     loop = partial(
@@ -904,7 +904,7 @@ def main():
             loop,
             on_train=on_train,
             on_eval=on_eval,
-            on_clip=on_clip,
+            on_after_optimize=on_after_optimize,
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)

@@ -985,12 +985,8 @@ def main():

                 accelerator.backward(loss)

-                if accelerator.sync_gradients:
-                    on_clip(lr_scheduler.get_last_lr()[0])
-
                 optimizer.step()
-                if not accelerator.optimizer_step_was_skipped:
-                    lr_scheduler.step()
+                lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

                 avg_loss.update(loss.detach_(), bsz)
@@ -998,6 +994,8 @@ def main():

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                 if args.use_ema:
                     ema_embeddings.step(
                         text_encoder.text_model.embeddings.temp_token_embedding.parameters())
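In train_ti.py the former on_clip callback becomes on_after_optimize and now runs only after gradients were actually synchronized. The strength it passes to embeddings.normalize() scales with the current learning rate and is capped at 1.0; for illustration, plain arithmetic rather than repository code:

# Illustrative only: the factor min(1.0, 100 * lr) used by on_after_optimize
# grows linearly with the learning rate until lr reaches 1e-2, then saturates.
for lr in (1e-5, 1e-4, 1e-3, 1e-2, 1e-1):
    print(f"lr={lr:g} -> strength={min(1.0, 100 * lr):g}")
# lr=1e-05 -> strength=0.001
# lr=0.0001 -> strength=0.01
# lr=0.001 -> strength=0.1
# lr=0.01 -> strength=1
# lr=0.1 -> strength=1
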
diff --git a/training/lr.py b/training/lr.py
index 01f7f5e..84e30a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -26,7 +26,8 @@ class LRFinder():
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-        on_clip: Callable[[float], None] = noop,
+        on_before_optimize: Callable[[], None] = noop,
+        on_after_optimize: Callable[[float], None] = noop,
         on_eval: Callable[[], _GeneratorContextManager] = nullcontext
     ):
         self.accelerator = accelerator
@@ -36,7 +37,8 @@ class LRFinder():
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
         self.on_train = on_train
-        self.on_clip = on_clip
+        self.on_before_optimize = on_before_optimize
+        self.on_after_optimize = on_after_optimize
         self.on_eval = on_eval

         # self.model_state = copy.deepcopy(model.state_dict())
@@ -94,14 +96,15 @@ class LRFinder():

                 self.accelerator.backward(loss)

-                if self.accelerator.sync_gradients:
-                    self.on_clip(lr_scheduler.get_last_lr()[0])
+                self.on_before_optimize()

                 self.optimizer.step()
                 lr_scheduler.step()
                 self.optimizer.zero_grad(set_to_none=True)

                 if self.accelerator.sync_gradients:
+                    self.on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                     progress_bar.update(1)

             self.model.eval()
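Taken together, the LRFinder changes split the old on_clip hook into two callbacks with a fixed ordering around the optimizer step. A minimal sketch of that contract as it appears in the loop above, assuming a Hugging Face Accelerate-style setup; train_step is a hypothetical helper, not part of the repository:

# Sketch of the callback ordering introduced above.
def train_step(accelerator, optimizer, lr_scheduler, loss,
               on_before_optimize=lambda: None,
               on_after_optimize=lambda lr: None):
    accelerator.backward(loss)

    # Called unconditionally, right before the optimizer step
    # (used above for gradient clipping).
    on_before_optimize()

    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad(set_to_none=True)

    # Called only when gradients were synchronized, i.e. after a real
    # optimizer step (used above to normalize the learned embeddings).
    if accelerator.sync_gradients:
        on_after_optimize(lr_scheduler.get_last_lr()[0])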