author     Volpeon <git@volpeon.ink>  2023-01-17 11:50:16 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-17 11:50:16 +0100
commit     555912a86b012382a78f1b2717c2e0fde5994a04 (patch)
tree       7569fa157ae63134febe569bc7a58933c2cf4b3c
parent     Update (diff)
Make embedding decay work like Adam decay
-rw-r--r--  train_ti.py              | 16
-rw-r--r--  training/strategy/ti.py  | 14
2 files changed, 9 insertions(+), 21 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 0891c49..fc34d27 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -159,7 +159,7 @@ def parse_args():
     parser.add_argument(
         "--tag_dropout",
         type=float,
-        default=0,
+        default=0.1,
         help="Tag dropout probability.",
     )
     parser.add_argument(
@@ -406,18 +406,12 @@ def parse_args():
         help="Embedding decay target."
     )
     parser.add_argument(
-        "--emb_decay_factor",
-        default=1,
+        "--emb_decay",
+        default=1e-1,
         type=float,
         help="Embedding decay factor."
     )
     parser.add_argument(
-        "--emb_decay_start",
-        default=0,
-        type=float,
-        help="Embedding decay start offset."
-    )
-    parser.add_argument(
         "--noise_timesteps",
         type=int,
         default=1000,
@@ -587,12 +581,10 @@ def main():
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
-        learning_rate=args.learning_rate,
         gradient_checkpointing=args.gradient_checkpointing,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
-        emb_decay_factor=args.emb_decay_factor,
-        emb_decay_start=args.emb_decay_start,
+        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
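
Note: the net effect on the command line is that --emb_decay_factor and --emb_decay_start are gone and a single --emb_decay flag remains, with its default raised to 1e-1. The following standalone sketch mirrors only the options touched by this commit; it is an illustration, not the real parse_args(), which defines many more flags (the emb_decay_target default here is assumed from the strategy signature rather than shown in this hunk).

    import argparse

    def parse_args(argv=None):
        parser = argparse.ArgumentParser()
        parser.add_argument("--tag_dropout", type=float, default=0.1,
                            help="Tag dropout probability.")
        parser.add_argument("--emb_decay_target", type=float, default=0.4,
                            help="Embedding decay target.")
        # "--emb_decay_factor" and "--emb_decay_start" are replaced by a single
        # "--emb_decay", used as a decoupled weight-decay coefficient.
        parser.add_argument("--emb_decay", type=float, default=1e-1,
                            help="Embedding decay factor.")
        return parser.parse_args(argv)

    args = parse_args(["--emb_decay", "1e-1"])
    print(args.emb_decay)  # 0.1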
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 081180f..eb6730b 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -32,12 +32,10 @@ def textual_inversion_strategy_callbacks(
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
-    learning_rate: float,
     gradient_checkpointing: bool = False,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
-    emb_decay_factor: float = 1,
-    emb_decay_start: float = 0,
+    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -120,17 +118,15 @@ def textual_inversion_strategy_callbacks(
         yield
 
     def on_after_optimize(lr: float):
-        if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-    @torch.no_grad()
-    def on_after_epoch(lr: float):
         if use_emb_decay:
             text_encoder.text_model.embeddings.normalize(
                 emb_decay_target,
-                min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
+                min(1.0, emb_decay * lr)
             )
 
+        if ema_embeddings is not None:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
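
Note on the commit message: AdamW-style decoupled weight decay scales the decay step by the current learning rate (param <- param - lr * weight_decay * param), whereas the old callback ramped the decay weight by the ratio of the current learning rate to the base learning rate. The sketch below contrasts the two weightings as they appear in this diff; it assumes lr is the per-step learning rate passed to on_after_optimize and uses made-up values, and the project's normalize() call itself is not reproduced.

    def old_decay_weight(lr, base_lr, emb_decay_factor=1.0, emb_decay_start=0.0):
        # Pre-commit: scale by the LR ratio relative to the base learning rate,
        # offset by emb_decay_start, clamped to [0, 1].
        return min(1.0, max(0.0, emb_decay_factor *
                            ((lr - emb_decay_start) / (base_lr - emb_decay_start))))

    def new_decay_weight(lr, emb_decay=1e-1):
        # Post-commit: scale directly by the current LR, as AdamW does for
        # decoupled weight decay, clamped to 1.
        return min(1.0, emb_decay * lr)

    # Illustrative values only (not from a real run): with a base LR of 1e-3,
    # the new weight tracks the LR itself rather than the LR ratio, so the decay
    # shrinks as the scheduler anneals the learning rate.
    for lr in (1e-3, 5e-4, 1e-4):
        print(lr, old_decay_weight(lr, base_lr=1e-3), new_decay_weight(lr))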