diff options
| author | Volpeon <git@volpeon.ink> | 2023-02-21 14:08:49 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-21 14:08:49 +0100 |
| commit | 96638bbd54ca7f91d44c938fae7275d3ecaa6add (patch) | |
| tree | b281a0e58820151e8738dfc5294bde5be482956b /training/strategy/dreambooth.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-96638bbd54ca7f91d44c938fae7275d3ecaa6add.tar.gz textual-inversion-diff-96638bbd54ca7f91d44c938fae7275d3ecaa6add.tar.bz2 textual-inversion-diff-96638bbd54ca7f91d44c938fae7275d3ecaa6add.zip | |
Fixed TI normalization order
Diffstat (limited to 'training/strategy/dreambooth.py')
| -rw-r--r-- | training/strategy/dreambooth.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index fcf5c0d..0290327 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -115,14 +115,13 @@ def dreambooth_strategy_callbacks( | |||
| 115 | yield | 115 | yield |
| 116 | 116 | ||
| 117 | def on_before_optimize(lr: float, epoch: int): | 117 | def on_before_optimize(lr: float, epoch: int): |
| 118 | if accelerator.sync_gradients: | 118 | params_to_clip = [unet.parameters()] |
| 119 | params_to_clip = [unet.parameters()] | 119 | if epoch < train_text_encoder_epochs: |
| 120 | if epoch < train_text_encoder_epochs: | 120 | params_to_clip.append(text_encoder.parameters()) |
| 121 | params_to_clip.append(text_encoder.parameters()) | 121 | accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) |
| 122 | accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) | ||
| 123 | 122 | ||
| 124 | @torch.no_grad() | 123 | @torch.no_grad() |
| 125 | def on_after_optimize(lr: float): | 124 | def on_after_optimize(_, lr: float): |
| 126 | if ema_unet is not None: | 125 | if ema_unet is not None: |
| 127 | ema_unet.step(unet.parameters()) | 126 | ema_unet.step(unet.parameters()) |
| 128 | 127 | ||
