From b5d3df18c3a56699a3658ad58a02d4494836972f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 21 Jun 2023 21:46:15 +0200 Subject: Update --- training/functional.py | 13 ------------- training/strategy/dreambooth.py | 10 ++++++++++ 2 files changed, 10 insertions(+), 13 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index f68faf9..3c7848f 100644 --- a/training/functional.py +++ b/training/functional.py @@ -348,7 +348,6 @@ def loss_step( guidance_scale: float, prior_loss_weight: float, seed: int, - offset_noise_strength: float, input_pertubation: float, disc: Optional[ConvNeXtDiscriminator], min_snr_gamma: int, @@ -377,16 +376,6 @@ def loss_step( ) applied_noise = noise - if offset_noise_strength != 0: - applied_noise = applied_noise + offset_noise_strength * perlin_noise( - latents.shape, - res=1, - octaves=4, - dtype=latents.dtype, - device=latents.device, - generator=generator, - ) - if input_pertubation != 0: applied_noise = applied_noise + input_pertubation * torch.randn( latents.shape, @@ -751,7 +740,6 @@ def train( global_step_offset: int = 0, guidance_scale: float = 0.0, prior_loss_weight: float = 1.0, - offset_noise_strength: float = 0.01, input_pertubation: float = 0.1, disc: Optional[ConvNeXtDiscriminator] = None, schedule_sampler: Optional[ScheduleSampler] = None, @@ -814,7 +802,6 @@ def train( guidance_scale, prior_loss_weight, seed, - offset_noise_strength, input_pertubation, disc, min_snr_gamma, diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 88b441b..43fe838 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -1,4 +1,5 @@ from typing import Optional +from types import MethodType from functools import partial from contextlib import contextmanager, nullcontext from pathlib import Path @@ -130,6 +131,9 @@ def dreambooth_strategy_callbacks( unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) + unet_.forward = MethodType(unet_.forward, unet_) + text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) + with ema_context(): pipeline = VlpnStableDiffusion( text_encoder=text_encoder_, @@ -185,6 +189,7 @@ def dreambooth_prepare( train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], lr_scheduler: torch.optim.lr_scheduler._LRScheduler, + text_encoder_unfreeze_last_n_layers: int = 2, **kwargs ): ( @@ -198,6 +203,11 @@ def dreambooth_prepare( text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler ) + for layer in text_encoder.text_model.encoder.layers[ + : (-1 * text_encoder_unfreeze_last_n_layers) + ]: + layer.requires_grad_(False) + text_encoder.text_model.embeddings.requires_grad_(False) return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler -- cgit v1.2.3-70-g09d2