From b5d3df18c3a56699a3658ad58a02d4494836972f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Jun 2023 21:46:15 +0200
Subject: Update

---
 train_dreambooth.py             |  9 +--------
 train_lora.py                   |  8 --------
 train_ti.py                     |  7 -------
 training/functional.py          | 13 -------------
 training/strategy/dreambooth.py | 10 ++++++++++
 5 files changed, 11 insertions(+), 36 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 659b84c..0543a35 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -245,12 +245,6 @@ def parse_args():
             " resolution"
         ),
     )
-    parser.add_argument(
-        "--offset_noise_strength",
-        type=float,
-        default=0,
-        help="Perlin offset noise strength.",
-    )
     parser.add_argument(
         "--input_pertubation",
         type=float,
@@ -496,7 +490,6 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss.",
     )
-    parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.")
     parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha")
     parser.add_argument(
         "--emb_dropout",
@@ -679,6 +672,7 @@ def main():
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
 
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
@@ -1074,7 +1068,6 @@ def main():
         sample_output_dir=dreambooth_sample_output_dir,
         checkpoint_output_dir=dreambooth_checkpoint_output_dir,
         sample_frequency=dreambooth_sample_frequency,
-        offset_noise_strength=args.offset_noise_strength,
         input_pertubation=args.input_pertubation,
         no_val=args.valid_set_size == 0,
         avg_loss=avg_loss,
diff --git a/train_lora.py b/train_lora.py
index fccf48d..b7ee2d6 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -258,12 +258,6 @@ def parse_args():
             " resolution"
         ),
     )
-    parser.add_argument(
-        "--offset_noise_strength",
-        type=float,
-        default=0,
-        help="Perlin offset noise strength.",
-    )
     parser.add_argument(
         "--input_pertubation",
         type=float,
@@ -1138,7 +1132,6 @@ def main():
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=pti_sample_frequency,
-        offset_noise_strength=0,
         input_pertubation=args.input_pertubation,
         no_val=True,
     )
@@ -1291,7 +1284,6 @@ def main():
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
-        offset_noise_strength=args.offset_noise_strength,
         input_pertubation=args.input_pertubation,
         no_val=args.valid_set_size == 0,
         avg_loss=avg_loss,
diff --git a/train_ti.py b/train_ti.py
index c6f0b3a..da0c03e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -229,12 +229,6 @@ def parse_args():
         choices=["all", "trailing", "leading", "between", "auto", "off"],
         help="Vector shuffling algorithm.",
     )
-    parser.add_argument(
-        "--offset_noise_strength",
-        type=float,
-        default=0,
-        help="Offset noise strength.",
-    )
     parser.add_argument(
         "--input_pertubation",
         type=float,
@@ -876,7 +870,6 @@ def main():
         checkpoint_frequency=args.checkpoint_frequency,
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
-        offset_noise_strength=args.offset_noise_strength,
         input_pertubation=args.input_pertubation,
         # --
         use_emb_decay=args.use_emb_decay,
diff --git a/training/functional.py b/training/functional.py
index f68faf9..3c7848f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -348,7 +348,6 @@ def loss_step(
     guidance_scale: float,
     prior_loss_weight: float,
     seed: int,
-    offset_noise_strength: float,
     input_pertubation: float,
     disc: Optional[ConvNeXtDiscriminator],
     min_snr_gamma: int,
@@ -377,16 +376,6 @@ def loss_step(
     )
 
     applied_noise = noise
-    if offset_noise_strength != 0:
-        applied_noise = applied_noise + offset_noise_strength * perlin_noise(
-            latents.shape,
-            res=1,
-            octaves=4,
-            dtype=latents.dtype,
-            device=latents.device,
-            generator=generator,
-        )
-
     if input_pertubation != 0:
         applied_noise = applied_noise + input_pertubation * torch.randn(
             latents.shape,
@@ -751,7 +740,6 @@ def train(
     global_step_offset: int = 0,
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
-    offset_noise_strength: float = 0.01,
     input_pertubation: float = 0.1,
     disc: Optional[ConvNeXtDiscriminator] = None,
     schedule_sampler: Optional[ScheduleSampler] = None,
@@ -814,7 +802,6 @@ def train(
         guidance_scale,
         prior_loss_weight,
         seed,
-        offset_noise_strength,
         input_pertubation,
         disc,
         min_snr_gamma,
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 88b441b..43fe838 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -1,4 +1,5 @@
 from typing import Optional
+from types import MethodType
 from functools import partial
 from contextlib import contextmanager, nullcontext
 from pathlib import Path
@@ -130,6 +131,9 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        unet_.forward = MethodType(unet_.forward, unet_)
+        text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+
         with ema_context():
             pipeline = VlpnStableDiffusion(
                 text_encoder=text_encoder_,
@@ -185,6 +189,7 @@ def dreambooth_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    text_encoder_unfreeze_last_n_layers: int = 2,
     **kwargs
 ):
     (
@@ -198,6 +203,11 @@
         text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
+    for layer in text_encoder.text_model.encoder.layers[
+        : (-1 * text_encoder_unfreeze_last_n_layers)
+    ]:
+        layer.requires_grad_(False)
+
     text_encoder.text_model.embeddings.requires_grad_(False)
 
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
--
cgit v1.2.3-70-g09d2
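
Note on the new freezing logic in dreambooth_prepare: after accelerator.prepare,
every text encoder transformer block except the last
text_encoder_unfreeze_last_n_layers (default 2) is frozen, and the token and
position embeddings stay frozen too, so only the top of the text encoder trains
alongside the UNet. A minimal standalone sketch of the same pattern, assuming a
Hugging Face CLIPTextModel; the helper name freeze_text_encoder is illustrative
and not part of this repo:

    from transformers import CLIPTextModel

    def freeze_text_encoder(text_encoder: CLIPTextModel, unfreeze_last_n: int = 2):
        # Freeze all transformer blocks except the last `unfreeze_last_n`
        # (assumes unfreeze_last_n >= 1; with 0 the slice is empty and
        # nothing gets frozen, same edge case as the patch's `-1 * n` slice).
        for layer in text_encoder.text_model.encoder.layers[:-unfreeze_last_n]:
            layer.requires_grad_(False)
        # Token and position embeddings stay frozen as well.
        text_encoder.text_model.embeddings.requires_grad_(False)
        return text_encoder

    text_encoder = freeze_text_encoder(
        CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    )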
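Note on the MethodType rebinding in dreambooth_strategy_callbacks: after
accelerator.unwrap_model(..., keep_fp32_wrapper=False), forward can presumably
be left on the instance as a plain function rather than a bound method, which
would break callers such as the sampling pipeline; types.MethodType binds it
back to the model. A minimal illustration of the binding semantics only
(Greeter is a made-up class for demonstration, not part of this repo):

    from types import MethodType

    class Greeter:
        def __init__(self, name):
            self.name = name

    greeter = Greeter("world")

    # A function stored as an instance attribute is NOT bound: calling
    # greeter.hello() at this point would raise TypeError (missing `self`).
    greeter.hello = lambda self: f"hello, {self.name}"

    # MethodType binds it to the instance so `self` is passed implicitly,
    # mirroring `unet_.forward = MethodType(unet_.forward, unet_)` above.
    greeter.hello = MethodType(greeter.hello, greeter)
    assert greeter.hello() == "hello, world"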