diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 0182693..73d9935 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -24,7 +24,7 @@ from slugify import slugify | |||
24 | from util import load_config, load_embeddings_from_dir | 24 | from util import load_config, load_embeddings_from_dir |
25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
26 | from data.csv import VlpnDataModule, VlpnDataItem | 26 | from data.csv import VlpnDataModule, VlpnDataItem |
27 | from training.common import run_model, generate_class_images | 27 | from training.common import loss_step, generate_class_images |
28 | from training.optimization import get_one_cycle_schedule | 28 | from training.optimization import get_one_cycle_schedule |
29 | from training.lr import LRFinder | 29 | from training.lr import LRFinder |
30 | from training.util import AverageMeter, CheckpointerBase, save_args | 30 | from training.util import AverageMeter, CheckpointerBase, save_args |
@@ -883,7 +883,7 @@ def main(): | |||
883 | pass | 883 | pass |
884 | 884 | ||
885 | loop = partial( | 885 | loop = partial( |
886 | run_model, | 886 | loss_step, |
887 | vae, | 887 | vae, |
888 | noise_scheduler, | 888 | noise_scheduler, |
889 | unet, | 889 | unet, |