diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 7 | ||||
-rw-r--r-- | training/strategy/ti.py | 15 |
2 files changed, 16 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py index 85dd884..739d055 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -362,6 +362,7 @@ def train_loop( | |||
362 | loss_step: LossCallable, | 362 | loss_step: LossCallable, |
363 | sample_frequency: int = 10, | 363 | sample_frequency: int = 10, |
364 | checkpoint_frequency: int = 50, | 364 | checkpoint_frequency: int = 50, |
365 | milestone_checkpoints: bool = True, | ||
365 | global_step_offset: int = 0, | 366 | global_step_offset: int = 0, |
366 | num_epochs: int = 100, | 367 | num_epochs: int = 100, |
367 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 368 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
@@ -514,7 +515,7 @@ def train_loop( | |||
514 | accelerator.log(logs, step=global_step) | 515 | accelerator.log(logs, step=global_step) |
515 | 516 | ||
516 | if accelerator.is_main_process: | 517 | if accelerator.is_main_process: |
517 | if avg_acc_val.avg.item() > best_acc_val: | 518 | if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints: |
518 | local_progress_bar.clear() | 519 | local_progress_bar.clear() |
519 | global_progress_bar.clear() | 520 | global_progress_bar.clear() |
520 | 521 | ||
@@ -527,7 +528,7 @@ def train_loop( | |||
527 | accs.append(avg_acc_val.avg.item()) | 528 | accs.append(avg_acc_val.avg.item()) |
528 | else: | 529 | else: |
529 | if accelerator.is_main_process: | 530 | if accelerator.is_main_process: |
530 | if avg_acc.avg.item() > best_acc: | 531 | if avg_acc.avg.item() > best_acc and milestone_checkpoints: |
531 | local_progress_bar.clear() | 532 | local_progress_bar.clear() |
532 | global_progress_bar.clear() | 533 | global_progress_bar.clear() |
533 | 534 | ||
@@ -572,6 +573,7 @@ def train( | |||
572 | num_train_epochs: int = 100, | 573 | num_train_epochs: int = 100, |
573 | sample_frequency: int = 20, | 574 | sample_frequency: int = 20, |
574 | checkpoint_frequency: int = 50, | 575 | checkpoint_frequency: int = 50, |
576 | milestone_checkpoints: bool = True, | ||
575 | global_step_offset: int = 0, | 577 | global_step_offset: int = 0, |
576 | with_prior_preservation: bool = False, | 578 | with_prior_preservation: bool = False, |
577 | prior_loss_weight: float = 1.0, | 579 | prior_loss_weight: float = 1.0, |
@@ -626,6 +628,7 @@ def train( | |||
626 | loss_step=loss_step_, | 628 | loss_step=loss_step_, |
627 | sample_frequency=sample_frequency, | 629 | sample_frequency=sample_frequency, |
628 | checkpoint_frequency=checkpoint_frequency, | 630 | checkpoint_frequency=checkpoint_frequency, |
631 | milestone_checkpoints=milestone_checkpoints, | ||
629 | global_step_offset=global_step_offset, | 632 | global_step_offset=global_step_offset, |
630 | num_epochs=num_train_epochs, | 633 | num_epochs=num_train_epochs, |
631 | callbacks=callbacks, | 634 | callbacks=callbacks, |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 66d3129..09beec4 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -116,10 +116,17 @@ def textual_inversion_strategy_callbacks( | |||
116 | @torch.no_grad() | 116 | @torch.no_grad() |
117 | def on_before_optimize(lr: float, epoch: int): | 117 | def on_before_optimize(lr: float, epoch: int): |
118 | if use_emb_decay: | 118 | if use_emb_decay: |
119 | text_encoder.text_model.embeddings.normalize( | 119 | lambda_ = emb_decay * lr |
120 | emb_decay_target, | 120 | |
121 | min(1.0, emb_decay * lr) | 121 | if lambda_ != 0: |
122 | ) | 122 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight |
123 | |||
124 | mask = torch.zeros(w.size(0), dtype=torch.bool) | ||
125 | mask[text_encoder.text_model.embeddings.temp_token_ids] = True | ||
126 | mask[torch.all(w.grad == 0, dim=1)] = False | ||
127 | |||
128 | norm = w[mask, :].norm(dim=-1, keepdim=True) | ||
129 | w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
123 | 130 | ||
124 | def on_after_optimize(lr: float): | 131 | def on_after_optimize(lr: float): |
125 | if ema_embeddings is not None: | 132 | if ema_embeddings is not None: |