-rw-r--r--  train_dreambooth.py              | 32
-rw-r--r--  train_lora.py                    | 18
-rw-r--r--  train_ti.py                      | 20
-rw-r--r--  training/strategy/dreambooth.py  | 26
4 files changed, 40 insertions(+), 56 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index d284346..c8f03ea 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -145,12 +145,6 @@ def parse_args():
         help="Tokens to create an alias for.",
     )
     parser.add_argument(
-        "--inverted_initializer_tokens",
-        type=str,
-        nargs="*",
-        help="A token to use as initializer word.",
-    )
-    parser.add_argument(
         "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
     parser.add_argument(
@@ -499,6 +493,15 @@ def parse_args():
         help="Embedding dropout probability.",
     )
     parser.add_argument(
+        "--use_emb_decay", action="store_true", help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target", default=0.4, type=float, help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay", default=1e2, type=float, help="Embedding decay factor."
+    )
+    parser.add_argument(
         "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
     )
     parser.add_argument(
@@ -554,18 +557,6 @@ def parse_args():
             "--placeholder_tokens and --initializer_tokens must have the same number of items"
         )
 
-    if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
-            args.placeholder_tokens
-        )
-
-    if (
-        isinstance(args.inverted_initializer_tokens, list)
-        and len(args.inverted_initializer_tokens) != 0
-    ):
-        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
-        args.initializer_tokens += args.inverted_initializer_tokens
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
@@ -875,6 +866,11 @@ def main():
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
+        placeholder_tokens=placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
         max_grad_norm=args.max_grad_norm,
     )
 
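
The new --use_emb_decay, --emb_decay_target and --emb_decay flags feed the keyword arguments that main() now passes through to the DreamBooth strategy callbacks. A standalone sketch of how the flags parse (not the project's full parse_args; the sample argv is made up):

# Hedged sketch: only the three new embedding-decay flags, defaults copied from the diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_emb_decay", action="store_true", help="Whether to use embedding decay."
)
parser.add_argument(
    "--emb_decay_target", default=0.4, type=float, help="Embedding decay target."
)
parser.add_argument(
    "--emb_decay", default=1e2, type=float, help="Embedding decay factor."
)

args = parser.parse_args(["--use_emb_decay", "--emb_decay", "1e-2"])
print(args.use_emb_decay, args.emb_decay_target, args.emb_decay)  # True 0.4 0.01

Note that the argparse default for --emb_decay is 1e2 while the strategy signature below defaults to 1e-2, so passing an explicit value avoids relying on either default.
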
diff --git a/train_lora.py b/train_lora.py
index 1ff25ff..fbec009 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -158,12 +158,6 @@ def parse_args():
         help="Tokens to create an alias for.",
     )
     parser.add_argument(
-        "--inverted_initializer_tokens",
-        type=str,
-        nargs="*",
-        help="A token to use as initializer word.",
-    )
-    parser.add_argument(
         "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
     parser.add_argument(
@@ -633,18 +627,6 @@ def parse_args():
             "--placeholder_tokens and --initializer_tokens must have the same number of items"
         )
 
-    if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
-            args.placeholder_tokens
-        )
-
-    if (
-        isinstance(args.inverted_initializer_tokens, list)
-        and len(args.inverted_initializer_tokens) != 0
-    ):
-        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
-        args.initializer_tokens += args.inverted_initializer_tokens
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
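
train_lora.py loses only the --inverted_initializer_tokens handling shown above. For reference, a condensed sketch of what the removed block used to do (token values are made up):

# Sketch of the removed behaviour: broadcast a single inverted initializer token and
# register an "inv_"-prefixed placeholder for every existing placeholder token.
placeholder_tokens = ["<tok1>", "<tok2>"]
initializer_tokens = ["cat", "dog"]
inverted_initializer_tokens = "person"  # a single string given on the CLI

if isinstance(inverted_initializer_tokens, str):
    inverted_initializer_tokens = [inverted_initializer_tokens] * len(placeholder_tokens)

if isinstance(inverted_initializer_tokens, list) and len(inverted_initializer_tokens) != 0:
    placeholder_tokens += [f"inv_{t}" for t in placeholder_tokens]
    initializer_tokens += inverted_initializer_tokens

print(placeholder_tokens)  # ['<tok1>', '<tok2>', 'inv_<tok1>', 'inv_<tok2>']
print(initializer_tokens)  # ['cat', 'dog', 'person', 'person']
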
diff --git a/train_ti.py b/train_ti.py
index 1dbd637..8c63493 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -112,12 +112,6 @@ def parse_args():
         help="Tokens to create an alias for.",
     )
     parser.add_argument(
-        "--inverted_initializer_tokens",
-        type=str,
-        nargs="*",
-        help="A token to use as initializer word.",
-    )
-    parser.add_argument(
         "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
     parser.add_argument(
@@ -545,18 +539,6 @@ def parse_args():
             "--placeholder_tokens and --initializer_tokens must have the same number of items"
         )
 
-    if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
-            args.placeholder_tokens
-        )
-
-    if (
-        isinstance(args.inverted_initializer_tokens, list)
-        and len(args.inverted_initializer_tokens) != 0
-    ):
-        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
-        args.initializer_tokens += args.inverted_initializer_tokens
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
@@ -872,7 +854,7 @@ def main():
 
     optimizer = create_optimizer(
         text_encoder.text_model.embeddings.token_embedding.parameters(),
-        lr=learning_rate,
+        lr=args.learning_rate,
     )
 
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
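
The one-line change in main() reads the learning rate from the parsed arguments instead of a bare learning_rate name, and the optimizer is still built over the token embedding only. A minimal stand-in sketch (torch.optim.AdamW substitutes for the repo's create_optimizer; the module layout and sizes are made up):

# Minimal stand-in: optimize only the token embedding, as train_ti.py does.
import torch
from torch import nn

text_model = nn.Module()                       # stand-in for text_encoder.text_model
text_model.embeddings = nn.Module()
text_model.embeddings.token_embedding = nn.Embedding(1000, 64)

class Args:                                    # stand-in for the parsed argparse namespace
    learning_rate = 1e-4

args = Args()

optimizer = torch.optim.AdamW(
    text_model.embeddings.token_embedding.parameters(),
    lr=args.learning_rate,                     # was `lr=learning_rate` before the fix
)
print(optimizer.param_groups[0]["lr"])  # 0.0001
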
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index dc19ba3..0f64747 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -30,8 +30,13 @@ def dreambooth_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
     train_text_encoder_cycles: int,
     text_encoder_unfreeze_last_n_layers: int = 2,
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     max_grad_norm: float = 1.0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
@@ -112,11 +117,29 @@ def dreambooth_strategy_callbacks(
             params_to_clip.append(text_encoder.parameters())
         accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
 
+        if len(placeholder_tokens) != 0 and use_emb_decay:
+            params = [
+                p
+                for p in text_encoder.text_model.embeddings.parameters()
+                if p.grad is not None
+            ]
+            return torch.stack(params) if len(params) != 0 else None
+
     @torch.no_grad()
-    def on_after_optimize(_, lrs: dict[str, float]):
+    def on_after_optimize(w, lrs: dict[str, float]):
         if ema_unet is not None:
             ema_unet.step(unet.parameters())
 
+        if w is not None and "emb" in lrs:
+            lr = lrs["emb"]
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
+
     def on_log():
         if ema_unet is not None:
             return {"ema_decay": ema_unet.decay}
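
With the new wiring, on_before_optimize returns the stacked embedding weights (w) whenever decay is enabled and gradients exist, and on_after_optimize nudges each row's norm toward emb_decay_target: with lambda = emb_decay * lr, each row w is updated as w <- w + (w / ||w||) * lambda * (emb_decay_target - ||w||). A self-contained numeric sketch of that update (all values are made up):

# Numeric sketch of the embedding-norm decay applied in on_after_optimize.
import torch

emb_decay = 1e-2
emb_decay_target = 0.4
lrs = {"emb": 1e-1}                # assumed learning-rate dict entry for the embeddings

w = torch.tensor([[3.0, 4.0],      # norm 5.00 -> pulled down toward 0.4
                  [0.03, 0.04]])   # norm 0.05 -> pushed up toward 0.4

lambda_ = emb_decay * lrs["emb"]
if lambda_ != 0:
    norm = w.norm(dim=-1, keepdim=True)
    w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

print(w.norm(dim=-1))  # tensor([4.9954, 0.0504]) -- both norms moved toward the target

The step size is proportional to how far each row's norm is from the target, so this acts as a soft constraint on embedding norms rather than a hard renormalization.
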
@@ -212,6 +235,7 @@ def dreambooth_prepare(
     ]:
         layer.requires_grad_(False)
 
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
     # text_encoder.text_model.embeddings.requires_grad_(False)
 
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
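
Freezing the position embedding in dreambooth_prepare leaves only token_embedding weights able to accumulate gradients, which is what the p.grad is not None filter in on_before_optimize relies on. A small sketch of the effect (the module layout is a simplified stand-in for CLIP's text-model embeddings; sizes are made up):

# Sketch: after freezing the position embedding, only the token embedding stays trainable.
from torch import nn

embeddings = nn.Module()                      # stand-in for text_encoder.text_model.embeddings
embeddings.token_embedding = nn.Embedding(10, 4)
embeddings.position_embedding = nn.Embedding(8, 4)

embeddings.position_embedding.requires_grad_(False)

trainable = [name for name, p in embeddings.named_parameters() if p.requires_grad]
print(trainable)  # ['token_embedding.weight']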