| author | Volpeon <git@volpeon.ink> | 2023-02-21 09:09:50 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-21 09:09:50 +0100 |
| commit | 16b92605a59d59c65789c89b54bb97da51908056 | |
| tree | b0cbf8677897c3f44c736b710fd034eb2c5de6a0 /training | |
| parent | Update | |
Embedding normalization: Ignore tensors with grad = 0
Diffstat (limited to 'training')

| | file | lines changed |
|---|---|---|
| -rw-r--r-- | training/functional.py | 7 |
| -rw-r--r-- | training/strategy/ti.py | 15 |

2 files changed, 16 insertions, 6 deletions
```diff
diff --git a/training/functional.py b/training/functional.py
index 85dd884..739d055 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -362,6 +362,7 @@ def train_loop(
     loss_step: LossCallable,
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
+    milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     callbacks: TrainingCallbacks = TrainingCallbacks(),
@@ -514,7 +515,7 @@ def train_loop(
             accelerator.log(logs, step=global_step)
 
             if accelerator.is_main_process:
-                if avg_acc_val.avg.item() > best_acc_val:
+                if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
@@ -527,7 +528,7 @@ def train_loop(
                 accs.append(avg_acc_val.avg.item())
         else:
             if accelerator.is_main_process:
-                if avg_acc.avg.item() > best_acc:
+                if avg_acc.avg.item() > best_acc and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
@@ -572,6 +573,7 @@ def train(
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
+    milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
@@ -626,6 +628,7 @@ def train(
         loss_step=loss_step_,
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
+        milestone_checkpoints=milestone_checkpoints,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         callbacks=callbacks,
```
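The new `milestone_checkpoints` flag gates the "save a checkpoint whenever accuracy improves" behaviour of the training loop. Below is a minimal, self-contained sketch of that gating pattern; the function and variable names are illustrative assumptions, not code from this repository:

```python
# Minimal sketch of the gating pattern added by this commit: a "best so far"
# checkpoint is only written when milestone checkpoints are enabled.
# All names here are illustrative, not the repository's own.
from typing import Callable


def maybe_checkpoint(avg_acc: float,
                     best_acc: float,
                     save_checkpoint: Callable[[], None],
                     milestone_checkpoints: bool = True) -> float:
    """Save a checkpoint on improved accuracy, but only if milestone
    checkpoints are enabled; return the (possibly updated) best accuracy."""
    if avg_acc > best_acc and milestone_checkpoints:
        save_checkpoint()
        best_acc = avg_acc
    return best_acc


# Usage: with milestone_checkpoints=False, per-milestone saves are skipped.
best = 0.0
best = maybe_checkpoint(0.82, best, lambda: print("checkpoint saved"))
best = maybe_checkpoint(0.85, best, lambda: print("checkpoint saved"),
                        milestone_checkpoints=False)
```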
```diff
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 66d3129..09beec4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -116,10 +116,17 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            text_encoder.text_model.embeddings.normalize(
-                emb_decay_target,
-                min(1.0, emb_decay * lr)
-            )
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+
+                mask = torch.zeros(w.size(0), dtype=torch.bool)
+                mask[text_encoder.text_model.embeddings.temp_token_ids] = True
+                mask[torch.all(w.grad == 0, dim=1)] = False
+
+                norm = w[mask, :].norm(dim=-1, keepdim=True)
+                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
 
     def on_after_optimize(lr: float):
         if ema_embeddings is not None:
```
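The change above replaces the generic `normalize(...)` call with an explicit decay step that skips embedding rows whose gradient is entirely zero, i.e. tokens that did not receive a gradient in the current step. A self-contained sketch of that masking-and-decay logic follows; the tensor names and helper signature are illustrative assumptions, not the repository's exact code:

```python
# Sketch of the decay-toward-target-norm step with the new grad == 0 mask.
# Shapes and the decay formula mirror the diff above; the function itself is
# an illustrative stand-in, not the repository's code.
import torch


@torch.no_grad()
def decay_embeddings(
    weight: torch.Tensor,         # (num_tokens, embedding_dim) embedding matrix
    grad: torch.Tensor,           # same shape, gradient from the last backward pass
    trainable_ids: torch.Tensor,  # indices of the learned (temp) token embeddings
    decay_target: float,          # norm the embeddings are pulled toward
    lambda_: float,               # decay strength, e.g. emb_decay * lr
) -> None:
    if lambda_ == 0:
        return

    # Start from the trainable rows only ...
    mask = torch.zeros(weight.size(0), dtype=torch.bool, device=weight.device)
    mask[trainable_ids] = True
    # ... and drop rows whose gradient is all zeros (tokens unused this step).
    mask[torch.all(grad == 0, dim=1)] = False

    # Pull each remaining row toward the target norm along its own direction.
    norm = weight[mask].norm(dim=-1, keepdim=True)
    weight[mask] += (weight[mask] / norm.clamp_min(1e-12)) * lambda_ * (decay_target - norm)


# Usage: row 2 has a zero gradient, so only rows 0 and 1 are decayed.
w = torch.randn(3, 4)
g = torch.zeros(3, 4)
g[0] = g[1] = 1.0
decay_embeddings(w, g, torch.tensor([0, 1, 2]), decay_target=0.4, lambda_=0.1)
```

One deviation from the diff: the sketch writes the update back with `weight[mask] += ...` rather than `w[mask].add_(...)`, because boolean-mask indexing in PyTorch returns a copy, so an in-place `add_` on that copy would not touch the original weight matrix.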
