| author | Volpeon <git@volpeon.ink> | 2023-01-12 13:50:22 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-12 13:50:22 +0100 |
| commit | f963d4cba5c4c6575d77be80621a40b615603ca3 | |
| tree | d60ecb8c99534a12cef8070b0cf5a77eecc1c8d1 | |
| parent | Fixed TI decay | |
Update
| -rw-r--r-- | train_dreambooth.py | 56 |
|---|---|---|
| -rw-r--r-- | train_ti.py | 30 |
| -rw-r--r-- | training/lr.py | 11 |

3 files changed, 63 insertions, 34 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 73d9935..ebcf802 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -87,6 +87,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
+    parser.add_argument(
         "--exclude_collections",
         type=str,
         nargs='*',
@@ -444,17 +450,29 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
     if isinstance(args.placeholder_token, str):
         args.placeholder_token = [args.placeholder_token]
 
     if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+
+    if len(args.initializer_token) == 0:
+        raise ValueError("You must specify --initializer_token")
 
     if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if args.num_vectors is None:
+        args.num_vectors = 1
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+
+    if len(args.placeholder_token) != len(args.num_vectors):
+        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -882,6 +900,18 @@ def main():
     finally:
         pass
 
+    def on_before_optimize():
+        if accelerator.sync_gradients:
+            params_to_clip = [unet.parameters()]
+            if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
+                params_to_clip.append(text_encoder.parameters())
+            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), args.max_grad_norm)
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        if not args.train_text_encoder:
+            text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+
     loop = partial(
         loss_step,
         vae,
@@ -915,10 +945,12 @@ def main():
             loop,
             on_train=tokenizer.train,
            on_eval=tokenizer.eval,
+            on_before_optimize=on_before_optimize,
+            on_after_optimize=on_after_optimize,
         )
-        lr_finder.run(end_lr=1e2)
+        lr_finder.run(num_epochs=100, end_lr=1e3)
 
-        plt.savefig(basepath.joinpath("lr.png"))
+        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
 
         quit()
@@ -999,13 +1031,7 @@ def main():
 
                 accelerator.backward(loss)
 
-                if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(unet.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                        else unet.parameters()
-                    )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                on_before_optimize()
 
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
@@ -1019,6 +1045,8 @@ def main():
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
+                    on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
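The reworked `parse_args()` checks above let a single `--initializer_token` or `--num_vectors` value be broadcast across several `--placeholder_token` entries before the length checks run. A minimal standalone sketch of that broadcasting behaviour (the helper name and plain-list inputs are illustrative and not part of the repository):

```python
from typing import List, Optional, Union


def normalize_token_args(
    placeholder_token: Union[str, List[str]],
    initializer_token: Union[str, List[str]],
    num_vectors: Optional[Union[int, List[int]]],
):
    # Accept a single string and treat it as a one-element list.
    if isinstance(placeholder_token, str):
        placeholder_token = [placeholder_token]

    # Broadcast a single initializer token across all placeholder tokens.
    if isinstance(initializer_token, str):
        initializer_token = [initializer_token] * len(placeholder_token)
    if len(initializer_token) == 0:
        raise ValueError("You must specify --initializer_token")

    # Default to one embedding vector per token; broadcast a single int.
    if num_vectors is None:
        num_vectors = 1
    if isinstance(num_vectors, int):
        num_vectors = [num_vectors] * len(initializer_token)

    if len(placeholder_token) != len(initializer_token):
        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
    if len(placeholder_token) != len(num_vectors):
        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")

    return placeholder_token, initializer_token, num_vectors


# Example: one placeholder token, one initializer word, four vectors.
print(normalize_token_args("<cat-toy>", "toy", 4))
# (['<cat-toy>'], ['toy'], [4])
```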
diff --git a/train_ti.py b/train_ti.py
index 890c465..9ec5cfb 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -452,27 +452,27 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
-    if len(args.initializer_token) == 0:
-        raise ValueError("You must specify --initializer_token")
-
     if isinstance(args.placeholder_token, str):
         args.placeholder_token = [args.placeholder_token]
 
     if len(args.placeholder_token) == 0:
         args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
 
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+
+    if len(args.initializer_token) == 0:
+        raise ValueError("You must specify --initializer_token")
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
     if args.num_vectors is None:
         args.num_vectors = 1
 
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.initializer_token)
 
-    if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
-
     if len(args.placeholder_token) != len(args.num_vectors):
         raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
@@ -867,7 +867,7 @@ def main():
         pass
 
     @torch.no_grad()
-    def on_clip(lr):
+    def on_after_optimize(lr: float):
         text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
 
     loop = partial(
@@ -904,7 +904,7 @@ def main():
             loop,
            on_train=on_train,
            on_eval=on_eval,
-            on_clip=on_clip,
+            on_after_optimize=on_after_optimize,
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)
 
@@ -985,12 +985,8 @@ def main():
 
                 accelerator.backward(loss)
 
-                if accelerator.sync_gradients:
-                    on_clip(lr_scheduler.get_last_lr()[0])
-
                 optimizer.step()
-                if not accelerator.optimizer_step_was_skipped:
-                    lr_scheduler.step()
+                lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
                 avg_loss.update(loss.detach_(), bsz)
@@ -998,6 +994,8 @@ def main():
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
+                    on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                     if args.use_ema:
                         ema_embeddings.step(
                             text_encoder.text_model.embeddings.temp_token_embedding.parameters())
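Both training scripts now re-normalize the temporary token embeddings in `on_after_optimize` with a strength of `min(1.0, 100 * lr)`. A quick sanity check of that expression in plain Python (independent of the repository's embeddings class):

```python
# Strength saturates at 1.0 once the learning rate reaches 1e-2
# and falls off linearly below that.
def normalize_strength(lr: float) -> float:
    return min(1.0, 100 * lr)

assert normalize_strength(5e-2) == 1.0              # clipped to full strength
assert abs(normalize_strength(1e-3) - 0.1) < 1e-12  # scaled down at lower lr
assert abs(normalize_strength(1e-4) - 0.01) < 1e-12
```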
diff --git a/training/lr.py b/training/lr.py
index 01f7f5e..84e30a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -26,7 +26,8 @@ class LRFinder():
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-        on_clip: Callable[[float], None] = noop,
+        on_before_optimize: Callable[[], None] = noop,
+        on_after_optimize: Callable[[float], None] = noop,
         on_eval: Callable[[], _GeneratorContextManager] = nullcontext
     ):
         self.accelerator = accelerator
@@ -36,7 +37,8 @@ class LRFinder():
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
         self.on_train = on_train
-        self.on_clip = on_clip
+        self.on_before_optimize = on_before_optimize
+        self.on_after_optimize = on_after_optimize
         self.on_eval = on_eval
 
         # self.model_state = copy.deepcopy(model.state_dict())
@@ -94,14 +96,15 @@ class LRFinder():
 
                 self.accelerator.backward(loss)
 
-                if self.accelerator.sync_gradients:
-                    self.on_clip(lr_scheduler.get_last_lr()[0])
+                self.on_before_optimize()
 
                 self.optimizer.step()
                 lr_scheduler.step()
                 self.optimizer.zero_grad(set_to_none=True)
 
                 if self.accelerator.sync_gradients:
+                    self.on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                     progress_bar.update(1)
 
             self.model.eval()
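With these changes, `LRFinder` exposes the same two hooks as the training loops: `on_before_optimize` fires after the backward pass and before the optimizer consumes the gradients (used for gradient clipping in train_dreambooth.py), and `on_after_optimize` receives the last learning rate once gradients have been synchronized (used for embedding normalization). A condensed sketch of a single step driving these hooks under the new interface (the loop skeleton and names are illustrative, not copied from the repository):

```python
def training_step(accelerator, optimizer, lr_scheduler, loss,
                  on_before_optimize=lambda: None,
                  on_after_optimize=lambda lr: None):
    # Backward pass through Accelerate so gradient accumulation/scaling still applies.
    accelerator.backward(loss)

    # Hook point for gradient clipping, right before the optimizer consumes the gradients.
    on_before_optimize()

    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad(set_to_none=True)

    # Only after a real (synchronized) optimizer step, report the learning rate that was used.
    if accelerator.sync_gradients:
        on_after_optimize(lr_scheduler.get_last_lr()[0])
```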
