From 96638bbd54ca7f91d44c938fae7275d3ecaa6add Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 21 Feb 2023 14:08:49 +0100 Subject: Fixed TI normalization order --- training/strategy/dreambooth.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) (limited to 'training/strategy/dreambooth.py') diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index fcf5c0d..0290327 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -115,14 +115,13 @@ def dreambooth_strategy_callbacks( yield def on_before_optimize(lr: float, epoch: int): - if accelerator.sync_gradients: - params_to_clip = [unet.parameters()] - if epoch < train_text_encoder_epochs: - params_to_clip.append(text_encoder.parameters()) - accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) + params_to_clip = [unet.parameters()] + if epoch < train_text_encoder_epochs: + params_to_clip.append(text_encoder.parameters()) + accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) @torch.no_grad() - def on_after_optimize(lr: float): + def on_after_optimize(_, lr: float): if ema_unet is not None: ema_unet.step(unet.parameters()) -- cgit v1.2.3-54-g00ecf