From 92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 11 Jan 2023 21:54:10 +0100 Subject: TI: Use grad clipping from LoRA #104 --- environment.yaml | 2 +- train_dreambooth.py | 4 ++-- train_ti.py | 19 +++++++++++-------- training/common.py | 2 +- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/environment.yaml b/environment.yaml index eff69ed..9af40eb 100644 --- a/environment.yaml +++ b/environment.yaml @@ -23,4 +23,4 @@ dependencies: - test-tube>=0.7.5 - transformers==4.25.1 - triton==2.0.0.dev20221202 - - xformers==0.0.16rc401 + - xformers==0.0.16rc403 diff --git a/train_dreambooth.py b/train_dreambooth.py index 0182693..73d9935 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -24,7 +24,7 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import VlpnDataModule, VlpnDataItem -from training.common import run_model, generate_class_images +from training.common import loss_step, generate_class_images from training.optimization import get_one_cycle_schedule from training.lr import LRFinder from training.util import AverageMeter, CheckpointerBase, save_args @@ -883,7 +883,7 @@ def main(): pass loop = partial( - run_model, + loss_step, vae, noise_scheduler, unet, diff --git a/train_ti.py b/train_ti.py index 4e2c3c5..1054a5d 100644 --- a/train_ti.py +++ b/train_ti.py @@ -7,6 +7,7 @@ from pathlib import Path from contextlib import contextmanager, nullcontext import torch +import torch.nn.functional as F import torch.utils.checkpoint from accelerate import Accelerator @@ -22,7 +23,7 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import VlpnDataModule, VlpnDataItem -from training.common import run_model, generate_class_images +from training.common import loss_step, generate_class_images from training.optimization import get_one_cycle_schedule from training.lr import LRFinder from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args @@ -165,7 +166,7 @@ def parse_args(): parser.add_argument( "--tag_dropout", type=float, - default=0, + default=0.1, help="Tag dropout probability.", ) parser.add_argument( @@ -866,14 +867,16 @@ def main(): finally: pass + @torch.no_grad() def on_clip(): - accelerator.clip_grad_norm_( - text_encoder.text_model.embeddings.temp_token_embedding.parameters(), - args.max_grad_norm - ) + embeddings = text_encoder.text_model.embeddings.temp_token_embedding + + pre_norm = embeddings.weight.norm(dim=-1, keepdim=True) + lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0]) + embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm)) loop = partial( - run_model, + loss_step, vae, noise_scheduler, unet, @@ -971,7 +974,7 @@ def main(): try: for epoch in range(num_epochs): - if accelerator.is_main_process: + if accelerator.is_main_process and False: if epoch % args.sample_frequency == 0: checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) diff --git a/training/common.py b/training/common.py index 67c2ab6..0b2ae44 100644 --- a/training/common.py +++ b/training/common.py @@ -58,7 +58,7 @@ def generate_class_images( torch.cuda.empty_cache() -def run_model( +def loss_step( vae: AutoencoderKL, noise_scheduler: DDPMScheduler, unet: UNet2DConditionModel, -- cgit v1.2.3-70-g09d2