 training/functional.py          |  8 ++++----
 training/strategy/dreambooth.py | 11 +++++------
 training/strategy/ti.py         | 15 ++++++++++-----
 3 files changed, 19 insertions(+), 15 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index e7c4320..b830261 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -38,8 +38,8 @@ class TrainingCallbacks():
     on_accum_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
-    on_before_optimize: Callable[[float, int], None] = const()
-    on_after_optimize: Callable[[float], None] = const()
+    on_before_optimize: Callable[[float, int], Any] = const()
+    on_after_optimize: Callable[[Any, float], None] = const()
     on_after_epoch: Callable[[float], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
     on_sample: Callable[[int], None] = const()
@@ -455,13 +455,13 @@ def train_loop(
             local_progress_bar.set_postfix(**logs)

             if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
+                before_optimize_result = on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)

                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

-                on_after_optimize(lr_scheduler.get_last_lr()[0])
+                on_after_optimize(before_optimize_result, lr_scheduler.get_last_lr()[0])

                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
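
Net effect of the training/functional.py change: on_before_optimize may now return a value, and train_loop hands that value back to on_after_optimize as its new first argument. A minimal, self-contained sketch of the new contract follows; the function bodies and values are illustrative only, not taken from the repository:

from typing import Any

def on_before_optimize(lr: float, epoch: int) -> Any:
    # Runs while gradients are still populated; whatever is returned here is
    # carried over to on_after_optimize by train_loop.
    return {"epoch": epoch, "lr_before_step": lr}

def on_after_optimize(before_optimize_result: Any, lr: float) -> None:
    # First argument is the value returned by on_before_optimize (may be None).
    print("carried over:", before_optimize_result, "lr now:", lr)

# The relevant slice of the training step, mirroring the new call order:
before_optimize_result = on_before_optimize(1e-4, epoch=0)
# optimizer.step(); lr_scheduler.step(); optimizer.zero_grad(set_to_none=True)
on_after_optimize(before_optimize_result, 1e-4)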
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index fcf5c0d..0290327 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -115,14 +115,13 @@ def dreambooth_strategy_callbacks(
             yield

     def on_before_optimize(lr: float, epoch: int):
-        if accelerator.sync_gradients:
-            params_to_clip = [unet.parameters()]
-            if epoch < train_text_encoder_epochs:
-                params_to_clip.append(text_encoder.parameters())
-            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
+        params_to_clip = [unet.parameters()]
+        if epoch < train_text_encoder_epochs:
+            params_to_clip.append(text_encoder.parameters())
+        accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)

     @torch.no_grad()
-    def on_after_optimize(lr: float):
+    def on_after_optimize(_, lr: float):
         if ema_unet is not None:
             ema_unet.step(unet.parameters())

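
In training/strategy/dreambooth.py, on_before_optimize drops its accelerator.sync_gradients guard and on_after_optimize gains a leading parameter (ignored here as _) for the value carried over from on_before_optimize. For reference, the clipping pattern chains the per-model parameter generators lazily; below is a standalone sketch using plain torch.nn.utils.clip_grad_norm_ in place of accelerator.clip_grad_norm_, with stand-in modules and made-up hyperparameters:

import itertools

import torch
import torch.nn as nn

unet = nn.Linear(4, 4)          # stand-in for the UNet
text_encoder = nn.Linear(4, 4)  # stand-in for the text encoder

# Produce some gradients to clip.
(unet(torch.randn(2, 4)).sum() + text_encoder(torch.randn(2, 4)).sum()).backward()

epoch, train_text_encoder_epochs, max_grad_norm = 0, 10, 1.0

params_to_clip = [unet.parameters()]
if epoch < train_text_encoder_epochs:
    params_to_clip.append(text_encoder.parameters())

# clip_grad_norm_ accepts any iterable of parameters, so the per-model
# generators can be chained without materializing them into a list first.
total_norm = nn.utils.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
print(f"pre-clip gradient norm: {total_norm:.3f}")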
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 09beec4..732cd74 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -116,6 +116,15 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
+            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+            return torch.all(w.grad == 0, dim=1)
+
+    @torch.no_grad()
+    def on_after_optimize(zero_ids, lr: float):
+        if ema_embeddings is not None:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+        if use_emb_decay:
             lambda_ = emb_decay * lr

             if lambda_ != 0:
@@ -123,15 +132,11 @@ def textual_inversion_strategy_callbacks(
 
                 mask = torch.zeros(w.size(0), dtype=torch.bool)
                 mask[text_encoder.text_model.embeddings.temp_token_ids] = True
-                mask[torch.all(w.grad == 0, dim=1)] = False
+                mask[zero_ids] = False

                 norm = w[mask, :].norm(dim=-1, keepdim=True)
                 w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

-    def on_after_optimize(lr: float):
-        if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
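
In training/strategy/ti.py, the zero-gradient check has to stay in on_before_optimize because train_loop calls optimizer.zero_grad(set_to_none=True) before on_after_optimize runs, at which point w.grad is gone; the mask is therefore computed early, returned, and comes back as zero_ids so the norm decay only touches placeholder rows that actually received a gradient this step. A toy, self-contained sketch of that mask logic (tensor sizes and token ids are made up):

import torch

# Toy embedding table standing in for temp_token_embedding.weight.
w = torch.nn.Parameter(torch.randn(6, 8))

# Pretend only rows 1 and 3 took part in this step's loss.
w.sum(dim=1)[[1, 3]].sum().backward()

# on_before_optimize: rows whose gradient is entirely zero were unused;
# the mask is captured while w.grad still exists and returned to the loop.
zero_ids = torch.all(w.grad == 0, dim=1)

# on_after_optimize: restrict the decay to placeholder tokens that actually
# received a gradient.
temp_token_ids = torch.tensor([1, 3, 5])        # hypothetical placeholder ids
mask = torch.zeros(w.size(0), dtype=torch.bool)
mask[temp_token_ids] = True
mask[zero_ids] = False

print(mask)  # True only for rows 1 and 3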