summary refs log tree commit diff stats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r-- training/functional.py 13
-rw-r--r-- training/strategy/dreambooth.py 10
2 files changed, 10 insertions, 13 deletions
diff --git a/training/functional.py b/training/functional.py
index f68faf9..3c7848f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -348,7 +348,6 @@ def loss_step(
348 guidance_scale: float, 348 guidance_scale: float,
349 prior_loss_weight: float, 349 prior_loss_weight: float,
350 seed: int, 350 seed: int,
351 offset_noise_strength: float,
352 input_pertubation: float, 351 input_pertubation: float,
353 disc: Optional[ConvNeXtDiscriminator], 352 disc: Optional[ConvNeXtDiscriminator],
354 min_snr_gamma: int, 353 min_snr_gamma: int,
@@ -377,16 +376,6 @@ def loss_step(
377 ) 376 )
378 applied_noise = noise 377 applied_noise = noise
379 378
380 if offset_noise_strength != 0:
381 applied_noise = applied_noise + offset_noise_strength * perlin_noise(
382 latents.shape,
383 res=1,
384 octaves=4,
385 dtype=latents.dtype,
386 device=latents.device,
387 generator=generator,
388 )
389
390 if input_pertubation != 0: 379 if input_pertubation != 0:
391 applied_noise = applied_noise + input_pertubation * torch.randn( 380 applied_noise = applied_noise + input_pertubation * torch.randn(
392 latents.shape, 381 latents.shape,
@@ -751,7 +740,6 @@ def train(
751 global_step_offset: int = 0, 740 global_step_offset: int = 0,
752 guidance_scale: float = 0.0, 741 guidance_scale: float = 0.0,
753 prior_loss_weight: float = 1.0, 742 prior_loss_weight: float = 1.0,
754 offset_noise_strength: float = 0.01,
755 input_pertubation: float = 0.1, 743 input_pertubation: float = 0.1,
756 disc: Optional[ConvNeXtDiscriminator] = None, 744 disc: Optional[ConvNeXtDiscriminator] = None,
757 schedule_sampler: Optional[ScheduleSampler] = None, 745 schedule_sampler: Optional[ScheduleSampler] = None,
@@ -814,7 +802,6 @@ def train(
814 guidance_scale, 802 guidance_scale,
815 prior_loss_weight, 803 prior_loss_weight,
816 seed, 804 seed,
817 offset_noise_strength,
818 input_pertubation, 805 input_pertubation,
819 disc, 806 disc,
820 min_snr_gamma, 807 min_snr_gamma,
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 88b441b..43fe838 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -1,4 +1,5 @@
1 from typing import Optional 1 from typing import Optional
2 from types import MethodType
2 from functools import partial 3 from functools import partial
3 from contextlib import contextmanager, nullcontext 4 from contextlib import contextmanager, nullcontext
4 from pathlib import Path 5 from pathlib import Path
@@ -130,6 +131,9 @@ def dreambooth_strategy_callbacks(
130 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) 131 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
131 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) 132 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
132 133
134 unet_.forward = MethodType(unet_.forward, unet_)
135 text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
136
133 with ema_context(): 137 with ema_context():
134 pipeline = VlpnStableDiffusion( 138 pipeline = VlpnStableDiffusion(
135 text_encoder=text_encoder_, 139 text_encoder=text_encoder_,
@@ -185,6 +189,7 @@ def dreambooth_prepare(
185 train_dataloader: DataLoader, 189 train_dataloader: DataLoader,
186 val_dataloader: Optional[DataLoader], 190 val_dataloader: Optional[DataLoader],
187 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 191 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
192 text_encoder_unfreeze_last_n_layers: int = 2,
188 **kwargs 193 **kwargs
189): 194):
190 ( 195 (
@@ -198,6 +203,11 @@ def dreambooth_prepare(
198 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 203 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
199 ) 204 )
200 205
206 for layer in text_encoder.text_model.encoder.layers[
207 : (-1 * text_encoder_unfreeze_last_n_layers)
208 ]:
209 layer.requires_grad_(False)
210
201 text_encoder.text_model.embeddings.requires_grad_(False) 211 text_encoder.text_model.embeddings.requires_grad_(False)
202 212
203 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 213 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler