Diffstat (limited to 'training')
-rw-r--r--  training/functional.py  |  7 +++++--
-rw-r--r--  training/strategy/ti.py | 15 +++++++++----
2 files changed, 16 insertions(+), 6 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 85dd884..739d055 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -362,6 +362,7 @@ def train_loop(
     loss_step: LossCallable,
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
+    milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     callbacks: TrainingCallbacks = TrainingCallbacks(),
@@ -514,7 +515,7 @@ def train_loop(
             accelerator.log(logs, step=global_step)

             if accelerator.is_main_process:
-                if avg_acc_val.avg.item() > best_acc_val:
+                if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()

@@ -527,7 +528,7 @@ def train_loop(
                 accs.append(avg_acc_val.avg.item())
         else:
             if accelerator.is_main_process:
-                if avg_acc.avg.item() > best_acc:
+                if avg_acc.avg.item() > best_acc and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()

@@ -572,6 +573,7 @@ def train(
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
+    milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
@@ -626,6 +628,7 @@ def train(
         loss_step=loss_step_,
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
+        milestone_checkpoints=milestone_checkpoints,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         callbacks=callbacks,
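The functional.py side threads one new opt-out flag, milestone_checkpoints, from train() into train_loop(). With it set to False, the periodic checkpoint_frequency saves are unaffected, but the extra checkpoint normally written whenever the running average accuracy beats the previous best is skipped, in both the validation branch (best_acc_val) and the train-only branch (best_acc). A condensed sketch of the gate, assuming the checkpoint call and the best-accuracy update live inside the gated block (the hunks above only show its first two lines); the helper name and the on_checkpoint callback are hypothetical:

    def maybe_milestone_checkpoint(
        avg_acc: float,
        best_acc: float,
        step: int,
        milestone_checkpoints: bool,
        on_checkpoint,
    ) -> float:
        # Hypothetical condensation of the gate added above: save only when
        # a new best accuracy is reached AND milestone saves are enabled.
        if avg_acc > best_acc and milestone_checkpoints:
            on_checkpoint(step, "milestone")  # assumed callback, not shown in this diff
            best_acc = avg_acc
        return best_acc

Note that because the flag is folded into the comparison itself, disabling milestones (under this reading) also freezes the best-accuracy bookkeeping rather than merely skipping the save.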
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 66d3129..09beec4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -116,10 +116,17 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            text_encoder.text_model.embeddings.normalize(
-                emb_decay_target,
-                min(1.0, emb_decay * lr)
-            )
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+
+                mask = torch.zeros(w.size(0), dtype=torch.bool)
+                mask[text_encoder.text_model.embeddings.temp_token_ids] = True
+                mask[torch.all(w.grad == 0, dim=1)] = False
+
+                norm = w[mask, :].norm(dim=-1, keepdim=True)
+                w[mask] += (w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)

     def on_after_optimize(lr: float):
         if ema_embeddings is not None:
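The strategy/ti.py side replaces the embeddings.normalize(...) helper with an explicit decay of each trained embedding toward a target norm. With lambda_ = emb_decay * lr, the update w <- w + lambda_ * (emb_decay_target - ||w||) * w / ||w|| moves each selected row's norm to (1 - lambda_) * ||w|| + lambda_ * emb_decay_target, i.e. a fraction-lambda_ step toward the target per optimizer step. The mask limits this to the temporary (newly trained) token embeddings and drops rows whose gradient is all zero, i.e. tokens absent from the current batch. A standalone sketch with illustrative values (shapes and constants are assumptions, not taken from the repo):

    import torch

    emb_decay, lr, emb_decay_target = 1e-2, 1e-1, 1.0  # illustrative values only
    lambda_ = emb_decay * lr

    w = torch.randn(8, 768) * 3.0           # stand-in for the embedding weight
    mask = torch.zeros(8, dtype=torch.bool)
    mask[2:4] = True                        # pretend rows 2-3 are trained tokens

    norm = w[mask].norm(dim=-1, keepdim=True)
    # Boolean indexing returns a copy, so write back with `w[mask] += ...`
    # (an index_put_ on w); calling .add_() on the temporary would leave w unchanged.
    w[mask] += (w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)

    print(w[mask].norm(dim=-1))             # norms nudged slightly toward 1.0

The clamp_min guard only matters for rows with near-zero norm; for typical embedding magnitudes it is inert.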