Diffstat (limited to 'training')

-rw-r--r--  training/functional.py           |  8
-rw-r--r--  training/strategy/dreambooth.py  | 11
-rw-r--r--  training/strategy/ti.py          | 15

3 files changed, 19 insertions, 15 deletions
diff --git a/training/functional.py b/training/functional.py
index e7c4320..b830261 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -38,8 +38,8 @@ class TrainingCallbacks():
     on_accum_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
-    on_before_optimize: Callable[[float, int], None] = const()
-    on_after_optimize: Callable[[float], None] = const()
+    on_before_optimize: Callable[[float, int], Any] = const()
+    on_after_optimize: Callable[[Any, float], None] = const()
     on_after_epoch: Callable[[float], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
     on_sample: Callable[[int], None] = const()
@@ -455,13 +455,13 @@ def train_loop(
                 local_progress_bar.set_postfix(**logs)

                 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                    on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
+                    before_optimize_result = on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)

                     optimizer.step()
                     lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)

-                    on_after_optimize(lr_scheduler.get_last_lr()[0])
+                    on_after_optimize(before_optimize_result, lr_scheduler.get_last_lr()[0])

                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
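
Note: the two hunks above change the callback contract of train_loop: on_before_optimize may now return a value, and that value is handed back to on_after_optimize as its first argument once the optimizer step has run. A minimal sketch of the pattern follows; the helper name optimizer_step is hypothetical and not part of the repo.

```python
from typing import Any, Callable

import torch

# Hypothetical stand-in for the relevant part of train_loop: whatever
# on_before_optimize returns is threaded into on_after_optimize after the step.
def optimizer_step(
    optimizer: torch.optim.Optimizer,
    lr: float,
    epoch: int,
    on_before_optimize: Callable[[float, int], Any],
    on_after_optimize: Callable[[Any, float], None],
) -> None:
    before_optimize_result = on_before_optimize(lr, epoch)  # e.g. a gradient-derived mask
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)                   # gradients are discarded here
    on_after_optimize(before_optimize_result, lr)           # pre-step state is still available
```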
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index fcf5c0d..0290327 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -115,14 +115,13 @@ def dreambooth_strategy_callbacks(
         yield

     def on_before_optimize(lr: float, epoch: int):
-        if accelerator.sync_gradients:
-            params_to_clip = [unet.parameters()]
-            if epoch < train_text_encoder_epochs:
-                params_to_clip.append(text_encoder.parameters())
-            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
+        params_to_clip = [unet.parameters()]
+        if epoch < train_text_encoder_epochs:
+            params_to_clip.append(text_encoder.parameters())
+        accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)

     @torch.no_grad()
-    def on_after_optimize(lr: float):
+    def on_after_optimize(_, lr: float):
         if ema_unet is not None:
             ema_unet.step(unet.parameters())
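
Note: two things change for the DreamBooth strategy. The accelerator.sync_gradients guard is dropped, presumably because train_loop only invokes on_before_optimize at gradient-accumulation boundaries anyway, and on_after_optimize gains a leading parameter for the pass-through value, which this strategy does not use. A standalone sketch of that shape, with a hypothetical factory name:

```python
import itertools

import torch

def make_callbacks(accelerator, unet, text_encoder, ema_unet,
                   train_text_encoder_epochs: int, max_grad_norm: float):
    def on_before_optimize(lr: float, epoch: int):
        # Clip unconditionally; the train loop only calls this on real optimizer steps.
        params_to_clip = [unet.parameters()]
        if epoch < train_text_encoder_epochs:
            params_to_clip.append(text_encoder.parameters())
        accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
        # No explicit return: the pass-through value is simply None.

    @torch.no_grad()
    def on_after_optimize(_, lr: float):
        # The `_` parameter swallows the (unused) result of on_before_optimize.
        if ema_unet is not None:
            ema_unet.step(unet.parameters())

    return on_before_optimize, on_after_optimize
```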
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 09beec4..732cd74 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -116,6 +116,15 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
+            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+            return torch.all(w.grad == 0, dim=1)
+
+    @torch.no_grad()
+    def on_after_optimize(zero_ids, lr: float):
+        if ema_embeddings is not None:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+        if use_emb_decay:
             lambda_ = emb_decay * lr

             if lambda_ != 0:
@@ -123,15 +132,11 @@ def textual_inversion_strategy_callbacks(

                 mask = torch.zeros(w.size(0), dtype=torch.bool)
                 mask[text_encoder.text_model.embeddings.temp_token_ids] = True
-                mask[torch.all(w.grad == 0, dim=1)] = False
+                mask[zero_ids] = False

                 norm = w[mask, :].norm(dim=-1, keepdim=True)
                 w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

-    def on_after_optimize(lr: float):
-        if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
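
Note: the textual-inversion strategy is the reason for the new contract. optimizer.zero_grad(set_to_none=True) runs before on_after_optimize, so the "which embeddings received no gradient" test can no longer be done after the step; on_before_optimize now computes it while the gradients still exist and returns it, and on_after_optimize receives it as zero_ids to exclude untouched rows from embedding decay. A toy, standalone sketch of that hand-off (not repo code; the lambda_ and target values are made up):

```python
import torch

# Toy embedding with 4 rows; only rows 1 and 3 take part in the forward pass.
emb = torch.nn.Embedding(4, 3)
emb(torch.tensor([1, 3])).sum().backward()

# --- what on_before_optimize now returns (grads still populated) ---
zero_ids = torch.all(emb.weight.grad == 0, dim=1)        # [True, False, True, False]

# --- what on_after_optimize does with it (after grads were zeroed) ---
lambda_, target = 1e-2, 0.4
with torch.no_grad():
    w = emb.weight
    mask = torch.ones(w.size(0), dtype=torch.bool)
    mask[zero_ids] = False                               # skip rows that got no gradient
    norm = w[mask, :].norm(dim=-1, keepdim=True)
    w[mask] = w[mask] + (w[mask] / norm.clamp_min(1e-12)) * lambda_ * (target - norm)
```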
