summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0182693..73d9935 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -24,7 +24,7 @@ from slugify import slugify
24from util import load_config, load_embeddings_from_dir 24from util import load_config, load_embeddings_from_dir
25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
26from data.csv import VlpnDataModule, VlpnDataItem 26from data.csv import VlpnDataModule, VlpnDataItem
27from training.common import run_model, generate_class_images 27from training.common import loss_step, generate_class_images
28from training.optimization import get_one_cycle_schedule 28from training.optimization import get_one_cycle_schedule
29from training.lr import LRFinder 29from training.lr import LRFinder
30from training.util import AverageMeter, CheckpointerBase, save_args 30from training.util import AverageMeter, CheckpointerBase, save_args
@@ -883,7 +883,7 @@ def main():
883 pass 883 pass
884 884
885 loop = partial( 885 loop = partial(
886 run_model, 886 loss_step,
887 vae, 887 vae,
888 noise_scheduler, 888 noise_scheduler,
889 unet, 889 unet,