 train_ti.py             | 16 ++++------------
 training/strategy/ti.py | 14 +++++---------
 2 files changed, 9 insertions(+), 21 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 0891c49..fc34d27 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -159,7 +159,7 @@ def parse_args():
     parser.add_argument(
         "--tag_dropout",
         type=float,
-        default=0,
+        default=0.1,
         help="Tag dropout probability.",
     )
     parser.add_argument(
@@ -406,18 +406,12 @@ def parse_args():
         help="Embedding decay target."
     )
     parser.add_argument(
-        "--emb_decay_factor",
-        default=1,
+        "--emb_decay",
+        default=1e-1,
         type=float,
         help="Embedding decay factor."
     )
     parser.add_argument(
-        "--emb_decay_start",
-        default=0,
-        type=float,
-        help="Embedding decay start offset."
-    )
-    parser.add_argument(
         "--noise_timesteps",
         type=int,
         default=1000,
@@ -587,12 +581,10 @@ def main():
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
-        learning_rate=args.learning_rate,
         gradient_checkpointing=args.gradient_checkpointing,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
-        emb_decay_factor=args.emb_decay_factor,
-        emb_decay_start=args.emb_decay_start,
+        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
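Note (not part of the commit): a minimal sketch of the resulting parse_args() surface, assuming --use_emb_decay is a plain store_true flag (its definition lies outside the hunks above). The old --emb_decay_factor / --emb_decay_start pair collapses into a single --emb_decay coefficient, and --tag_dropout now defaults to 0.1.

# Sketch only; flag definitions outside the shown hunks are assumptions.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_emb_decay", action="store_true")          # assumed definition
parser.add_argument("--emb_decay_target", default=0.4, type=float,   # 0.4 mirrors the ti.py default
                    help="Embedding decay target.")
parser.add_argument("--emb_decay", default=1e-1, type=float,         # replaces factor/start pair
                    help="Embedding decay factor.")
parser.add_argument("--tag_dropout", default=0.1, type=float,        # new default from this commit
                    help="Tag dropout probability.")

# Old invocation: --use_emb_decay --emb_decay_factor 1 --emb_decay_start 0
# New invocation: --use_emb_decay --emb_decay 1e-1
args = parser.parse_args(["--use_emb_decay", "--emb_decay", "1e-1"])
print(args.emb_decay, args.emb_decay_target, args.tag_dropout)       # 0.1 0.4 0.1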
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 081180f..eb6730b 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -32,12 +32,10 @@ def textual_inversion_strategy_callbacks(
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
-    learning_rate: float,
     gradient_checkpointing: bool = False,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
-    emb_decay_factor: float = 1,
-    emb_decay_start: float = 0,
+    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -120,17 +118,15 @@ def textual_inversion_strategy_callbacks(
         yield
 
     def on_after_optimize(lr: float):
-        if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-    @torch.no_grad()
-    def on_after_epoch(lr: float):
         if use_emb_decay:
             text_encoder.text_model.embeddings.normalize(
                 emb_decay_target,
-                min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
+                min(1.0, emb_decay * lr)
             )
 
+        if ema_embeddings is not None:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
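Note (illustration, not repository code): with this change the weight passed to embeddings.normalize() no longer depends on the configured peak learning_rate or a start offset; it is simply the current lr scaled by emb_decay and clamped at 1.0, applied every optimizer step rather than once per epoch. The sketch below compares the two formulas numerically; peak_lr is a hypothetical value, and the other defaults mirror the diffs above (emb_decay_factor=1, emb_decay_start=0, new --emb_decay default 1e-1).

# Toy comparison of the old and new decay-weight formulas (illustrative values only).
def old_weight(lr, learning_rate, emb_decay_factor=1.0, emb_decay_start=0.0):
    # old: position of the current lr between the start offset and the peak
    # learning rate, scaled by the factor and clamped to [0, 1]
    return min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))

def new_weight(lr, emb_decay=1e-1):
    # new: a single coefficient times the current lr, clamped at 1.0
    return min(1.0, emb_decay * lr)

peak_lr = 1e-2  # hypothetical --learning_rate
for lr in (1e-2, 5e-3, 1e-3):
    print(f"lr={lr:.0e}  old={old_weight(lr, peak_lr):.3f}  new={new_weight(lr):.5f}")
# old follows lr / peak_lr (1.000, 0.500, 0.100); new is just 0.1 * lr
# (0.00100, 0.00050, 0.00010), so the callback no longer needs the peak lr at all.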
