-rw-r--r--   train_dreambooth.py              32
-rw-r--r--   train_lora.py                    18
-rw-r--r--   train_ti.py                      20
-rw-r--r--   training/strategy/dreambooth.py  26
4 files changed, 40 insertions, 56 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index d284346..c8f03ea 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -145,12 +145,6 @@ def parse_args():
         help="Tokens to create an alias for.",
     )
     parser.add_argument(
-        "--inverted_initializer_tokens",
-        type=str,
-        nargs="*",
-        help="A token to use as initializer word.",
-    )
-    parser.add_argument(
         "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
     parser.add_argument(
@@ -499,6 +493,15 @@ def parse_args():
         help="Embedding dropout probability.",
     )
     parser.add_argument(
+        "--use_emb_decay", action="store_true", help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target", default=0.4, type=float, help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay", default=1e2, type=float, help="Embedding decay factor."
+    )
+    parser.add_argument(
         "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
     )
     parser.add_argument(
@@ -554,18 +557,6 @@ def parse_args():
             "--placeholder_tokens and --initializer_tokens must have the same number of items"
         )
 
-    if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
-            args.placeholder_tokens
-        )
-
-    if (
-        isinstance(args.inverted_initializer_tokens, list)
-        and len(args.inverted_initializer_tokens) != 0
-    ):
-        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
-        args.initializer_tokens += args.inverted_initializer_tokens
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
@@ -875,6 +866,11 @@ def main():
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
+        placeholder_tokens=placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
         max_grad_norm=args.max_grad_norm,
     )
 
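For context, a standalone sketch of how the three new flags parse. The flag names, defaults, and help strings are copied from the hunk above; the values passed to parse_args are a hypothetical invocation, not part of the commit:

import argparse

parser = argparse.ArgumentParser()
# Flag names and defaults as added to train_dreambooth.py above.
parser.add_argument("--use_emb_decay", action="store_true", help="Whether to use embedding decay.")
parser.add_argument("--emb_decay_target", default=0.4, type=float, help="Embedding decay target.")
parser.add_argument("--emb_decay", default=1e2, type=float, help="Embedding decay factor.")

# Hypothetical invocation: enable decay and override the factor.
args = parser.parse_args(["--use_emb_decay", "--emb_decay", "1e-2"])
print(args.use_emb_decay, args.emb_decay_target, args.emb_decay)  # True 0.4 0.01

Without --use_emb_decay the flag stays False (action="store_true"), so the strategy callbacks below leave the embeddings untouched.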
diff --git a/train_lora.py b/train_lora.py
index 1ff25ff..fbec009 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -158,12 +158,6 @@ def parse_args():
         help="Tokens to create an alias for.",
     )
     parser.add_argument(
-        "--inverted_initializer_tokens",
-        type=str,
-        nargs="*",
-        help="A token to use as initializer word.",
-    )
-    parser.add_argument(
         "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
     parser.add_argument(
@@ -633,18 +627,6 @@ def parse_args():
             "--placeholder_tokens and --initializer_tokens must have the same number of items"
        )
 
-    if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
-            args.placeholder_tokens
-        )
-
-    if (
-        isinstance(args.inverted_initializer_tokens, list)
-        and len(args.inverted_initializer_tokens) != 0
-    ):
-        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
-        args.initializer_tokens += args.inverted_initializer_tokens
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
diff --git a/train_ti.py b/train_ti.py
index 1dbd637..8c63493 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -112,12 +112,6 @@ def parse_args():
         help="Tokens to create an alias for.",
     )
     parser.add_argument(
-        "--inverted_initializer_tokens",
-        type=str,
-        nargs="*",
-        help="A token to use as initializer word.",
-    )
-    parser.add_argument(
         "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
     parser.add_argument(
@@ -545,18 +539,6 @@ def parse_args():
             "--placeholder_tokens and --initializer_tokens must have the same number of items"
         )
 
-    if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
-            args.placeholder_tokens
-        )
-
-    if (
-        isinstance(args.inverted_initializer_tokens, list)
-        and len(args.inverted_initializer_tokens) != 0
-    ):
-        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
-        args.initializer_tokens += args.inverted_initializer_tokens
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
@@ -872,7 +854,7 @@ def main():
 
     optimizer = create_optimizer(
         text_encoder.text_model.embeddings.token_embedding.parameters(),
-        lr=learning_rate,
+        lr=args.learning_rate,
     )
 
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
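The last train_ti.py hunk fixes the optimizer to read the learning rate from args. A minimal sketch of the surviving pattern, training only a token-embedding table; a plain nn.Embedding stands in for text_encoder.text_model.embeddings.token_embedding and AdamW stands in for whatever create_optimizer returns, and the sizes, ids, and loss are purely illustrative:

import torch
import torch.nn as nn

# Stand-in for text_encoder.text_model.embeddings.token_embedding.
token_embedding = nn.Embedding(num_embeddings=49410, embedding_dim=768)

# Only the embedding table receives gradients; the rest of the model would stay frozen.
optimizer = torch.optim.AdamW(token_embedding.parameters(), lr=1e-4)

ids = torch.tensor([[49408, 49409]])       # hypothetical placeholder token ids
loss = token_embedding(ids).pow(2).mean()  # dummy loss, purely for illustration
loss.backward()
optimizer.step()
optimizer.zero_grad()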
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index dc19ba3..0f64747 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -30,8 +30,13 @@ def dreambooth_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
     train_text_encoder_cycles: int,
     text_encoder_unfreeze_last_n_layers: int = 2,
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     max_grad_norm: float = 1.0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
@@ -112,11 +117,29 @@ def dreambooth_strategy_callbacks(
             params_to_clip.append(text_encoder.parameters())
         accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
 
+        if len(placeholder_tokens) != 0 and use_emb_decay:
+            params = [
+                p
+                for p in text_encoder.text_model.embeddings.parameters()
+                if p.grad is not None
+            ]
+            return torch.stack(params) if len(params) != 0 else None
+
     @torch.no_grad()
-    def on_after_optimize(_, lrs: dict[str, float]):
+    def on_after_optimize(w, lrs: dict[str, float]):
         if ema_unet is not None:
             ema_unet.step(unet.parameters())
 
+        if w is not None and "emb" in lrs:
+            lr = lrs["emb"]
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
+
     def on_log():
         if ema_unet is not None:
             return {"ema_decay": ema_unet.decay}
@@ -212,6 +235,7 @@ def dreambooth_prepare(
     ]:
         layer.requires_grad_(False)
 
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
     # text_encoder.text_model.embeddings.requires_grad_(False)
 
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
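Taken together, on_before_optimize now hands the stacked placeholder embeddings to on_after_optimize, which nudges each row's norm toward emb_decay_target at a rate of emb_decay times the current "emb" learning rate. A self-contained sketch of that decay step, with a random tensor standing in for the stacked embeddings and the emb_decay default taken from the strategy signature above:

import torch


@torch.no_grad()
def apply_emb_decay(w: torch.Tensor, lr: float, emb_decay: float = 1e-2, emb_decay_target: float = 0.4) -> None:
    # Step size scales with the current learning rate, matching lambda_ = emb_decay * lr above.
    lambda_ = emb_decay * lr
    if lambda_ != 0:
        norm = w.norm(dim=-1, keepdim=True)
        # Move each row along its own direction by lambda_ * (target - norm):
        # rows longer than the target shrink, shorter rows grow.
        w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))


# Toy run: three 768-dim rows start with norm around 27.7 and drift toward 0.4.
w = torch.randn(3, 768)
for _ in range(1000):
    apply_emb_decay(w, lr=1.0)
print(w.norm(dim=-1))  # all three norms end up close to emb_decay_target

Because the step is proportional to (emb_decay_target - norm), the norms decay geometrically toward the target rather than collapsing to zero, which is the point of decaying toward a target norm instead of applying plain weight decay.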
