From 92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 11 Jan 2023 21:54:10 +0100 Subject: TI: Use grad clipping from LoRA #104 --- train_dreambooth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 0182693..73d9935 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -24,7 +24,7 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import VlpnDataModule, VlpnDataItem -from training.common import run_model, generate_class_images +from training.common import loss_step, generate_class_images from training.optimization import get_one_cycle_schedule from training.lr import LRFinder from training.util import AverageMeter, CheckpointerBase, save_args @@ -883,7 +883,7 @@ def main(): pass loop = partial( - run_model, + loss_step, vae, noise_scheduler, unet, -- cgit v1.2.3-54-g00ecf