summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-11 21:54:10 +0100
committerVolpeon <git@volpeon.ink>2023-01-11 21:54:10 +0100
commit92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d (patch)
tree7c032c88126b15e2a9eb13ccc4a8293e8d660f29 /train_dreambooth.py
parentBetter defaults (diff)
downloadtextual-inversion-diff-92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d.tar.gz
textual-inversion-diff-92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d.tar.bz2
textual-inversion-diff-92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d.zip
TI: Use grad clipping from LoRA #104
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0182693..73d9935 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -24,7 +24,7 @@ from slugify import slugify
24from util import load_config, load_embeddings_from_dir 24from util import load_config, load_embeddings_from_dir
25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
26from data.csv import VlpnDataModule, VlpnDataItem 26from data.csv import VlpnDataModule, VlpnDataItem
27from training.common import run_model, generate_class_images 27from training.common import loss_step, generate_class_images
28from training.optimization import get_one_cycle_schedule 28from training.optimization import get_one_cycle_schedule
29from training.lr import LRFinder 29from training.lr import LRFinder
30from training.util import AverageMeter, CheckpointerBase, save_args 30from training.util import AverageMeter, CheckpointerBase, save_args
@@ -883,7 +883,7 @@ def main():
883 pass 883 pass
884 884
885 loop = partial( 885 loop = partial(
886 run_model, 886 loss_step,
887 vae, 887 vae,
888 noise_scheduler, 888 noise_scheduler,
889 unet, 889 unet,