summaryrefslogtreecommitdiffstats
path: root/training/strategy/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-21 14:08:49 +0100
committerVolpeon <git@volpeon.ink>2023-02-21 14:08:49 +0100
commit96638bbd54ca7f91d44c938fae7275d3ecaa6add (patch)
treeb281a0e58820151e8738dfc5294bde5be482956b /training/strategy/dreambooth.py
parentFix (diff)
downloadtextual-inversion-diff-96638bbd54ca7f91d44c938fae7275d3ecaa6add.tar.gz
textual-inversion-diff-96638bbd54ca7f91d44c938fae7275d3ecaa6add.tar.bz2
textual-inversion-diff-96638bbd54ca7f91d44c938fae7275d3ecaa6add.zip
Fixed TI normalization order
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--training/strategy/dreambooth.py11
1 file changed, 5 insertions, 6 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index fcf5c0d..0290327 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -115,14 +115,13 @@ def dreambooth_strategy_callbacks(
         yield

     def on_before_optimize(lr: float, epoch: int):
-        if accelerator.sync_gradients:
-            params_to_clip = [unet.parameters()]
-            if epoch < train_text_encoder_epochs:
-                params_to_clip.append(text_encoder.parameters())
-            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
+        params_to_clip = [unet.parameters()]
+        if epoch < train_text_encoder_epochs:
+            params_to_clip.append(text_encoder.parameters())
+        accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)

     @torch.no_grad()
-    def on_after_optimize(lr: float):
+    def on_after_optimize(_, lr: float):
         if ema_unet is not None:
             ema_unet.step(unet.parameters())
