-rw-r--r--  train_dreambooth.py              | 32
-rw-r--r--  train_lora.py                    | 18
-rw-r--r--  train_ti.py                      | 20
-rw-r--r--  training/strategy/dreambooth.py  | 26
4 files changed, 40 insertions(+), 56 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index d284346..c8f03ea 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -145,12 +145,6 @@ def parse_args():
         help="Tokens to create an alias for.",
     )
     parser.add_argument(
-        "--inverted_initializer_tokens",
-        type=str,
-        nargs="*",
-        help="A token to use as initializer word.",
-    )
-    parser.add_argument(
         "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
     parser.add_argument(
@@ -499,6 +493,15 @@ def parse_args():
         help="Embedding dropout probability.",
     )
     parser.add_argument(
+        "--use_emb_decay", action="store_true", help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target", default=0.4, type=float, help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay", default=1e2, type=float, help="Embedding decay factor."
+    )
+    parser.add_argument(
         "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
     )
     parser.add_argument(
@@ -554,18 +557,6 @@ def parse_args():
             "--placeholder_tokens and --initializer_tokens must have the same number of items"
         )

-    if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
-            args.placeholder_tokens
-        )
-
-    if (
-        isinstance(args.inverted_initializer_tokens, list)
-        and len(args.inverted_initializer_tokens) != 0
-    ):
-        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
-        args.initializer_tokens += args.inverted_initializer_tokens
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)

@@ -875,6 +866,11 @@ def main():
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
+        placeholder_tokens=placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
         max_grad_norm=args.max_grad_norm,
     )

diff --git a/train_lora.py b/train_lora.py
index 1ff25ff..fbec009 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -158,12 +158,6 @@ def parse_args():
         help="Tokens to create an alias for.",
     )
     parser.add_argument(
-        "--inverted_initializer_tokens",
-        type=str,
-        nargs="*",
-        help="A token to use as initializer word.",
-    )
-    parser.add_argument(
         "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
     parser.add_argument(
@@ -633,18 +627,6 @@ def parse_args():
             "--placeholder_tokens and --initializer_tokens must have the same number of items"
         )

-    if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
-            args.placeholder_tokens
-        )
-
-    if (
-        isinstance(args.inverted_initializer_tokens, list)
-        and len(args.inverted_initializer_tokens) != 0
-    ):
-        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
-        args.initializer_tokens += args.inverted_initializer_tokens
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)

diff --git a/train_ti.py b/train_ti.py
index 1dbd637..8c63493 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -112,12 +112,6 @@ def parse_args():
         help="Tokens to create an alias for.",
     )
     parser.add_argument(
-        "--inverted_initializer_tokens",
-        type=str,
-        nargs="*",
-        help="A token to use as initializer word.",
-    )
-    parser.add_argument(
         "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
     parser.add_argument(
@@ -545,18 +539,6 @@ def parse_args():
             "--placeholder_tokens and --initializer_tokens must have the same number of items"
         )

-    if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
-            args.placeholder_tokens
-        )
-
-    if (
-        isinstance(args.inverted_initializer_tokens, list)
-        and len(args.inverted_initializer_tokens) != 0
-    ):
-        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
-        args.initializer_tokens += args.inverted_initializer_tokens
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)

@@ -872,7 +854,7 @@ def main():

     optimizer = create_optimizer(
         text_encoder.text_model.embeddings.token_embedding.parameters(),
-        lr=learning_rate,
+        lr=args.learning_rate,
     )

     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index dc19ba3..0f64747 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -30,8 +30,13 @@ def dreambooth_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
     train_text_encoder_cycles: int,
     text_encoder_unfreeze_last_n_layers: int = 2,
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     max_grad_norm: float = 1.0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
@@ -112,11 +117,29 @@ def dreambooth_strategy_callbacks(
             params_to_clip.append(text_encoder.parameters())
         accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)

+        if len(placeholder_tokens) != 0 and use_emb_decay:
+            params = [
+                p
+                for p in text_encoder.text_model.embeddings.parameters()
+                if p.grad is not None
+            ]
+            return torch.stack(params) if len(params) != 0 else None
+
     @torch.no_grad()
-    def on_after_optimize(_, lrs: dict[str, float]):
+    def on_after_optimize(w, lrs: dict[str, float]):
         if ema_unet is not None:
             ema_unet.step(unet.parameters())

+        if w is not None and "emb" in lrs:
+            lr = lrs["emb"]
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
+
     def on_log():
         if ema_unet is not None:
             return {"ema_decay": ema_unet.decay}
@@ -212,6 +235,7 @@ def dreambooth_prepare(
     ]:
         layer.requires_grad_(False)

+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
     # text_encoder.text_model.embeddings.requires_grad_(False)

     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
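
Note on the embedding decay wired up in on_after_optimize above: when --use_emb_decay is set, each optimizer step nudges the L2 norm of the trainable embedding vectors toward emb_decay_target, scaled by emb_decay times the "emb" learning rate from lrs. The following is a minimal standalone sketch of that update, assuming w is an (n, dim) tensor of embedding vectors (as returned by on_before_optimize); the function name and toy values are illustrative, not part of the commit.

import torch

def emb_decay_step(w: torch.Tensor, lr: float, emb_decay: float, emb_decay_target: float) -> None:
    # Pull each embedding vector's L2 norm toward emb_decay_target, scaled by emb_decay * lr.
    lambda_ = emb_decay * lr
    if lambda_ != 0:
        norm = w.norm(dim=-1, keepdim=True)  # per-vector L2 norm, shape (n, 1)
        w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

# Toy check (illustrative values): a vector with norm above the target shrinks,
# one below it grows.
w = torch.tensor([[3.0, 4.0], [0.1, 0.0]])  # norms 5.0 and 0.1
emb_decay_step(w, lr=1e-3, emb_decay=1e2, emb_decay_target=0.4)
print(w.norm(dim=-1))  # ~4.54 and ~0.13

Because the correction is proportional to (emb_decay_target - norm), the update acts in both directions and behaves as a soft norm constraint on the learned embeddings rather than plain weight decay toward zero.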