diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 0182693..73d9935 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -24,7 +24,7 @@ from slugify import slugify | |||
| 24 | from util import load_config, load_embeddings_from_dir | 24 | from util import load_config, load_embeddings_from_dir |
| 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 26 | from data.csv import VlpnDataModule, VlpnDataItem | 26 | from data.csv import VlpnDataModule, VlpnDataItem |
| 27 | from training.common import run_model, generate_class_images | 27 | from training.common import loss_step, generate_class_images |
| 28 | from training.optimization import get_one_cycle_schedule | 28 | from training.optimization import get_one_cycle_schedule |
| 29 | from training.lr import LRFinder | 29 | from training.lr import LRFinder |
| 30 | from training.util import AverageMeter, CheckpointerBase, save_args | 30 | from training.util import AverageMeter, CheckpointerBase, save_args |
| @@ -883,7 +883,7 @@ def main(): | |||
| 883 | pass | 883 | pass |
| 884 | 884 | ||
| 885 | loop = partial( | 885 | loop = partial( |
| 886 | run_model, | 886 | loss_step, |
| 887 | vae, | 887 | vae, |
| 888 | noise_scheduler, | 888 | noise_scheduler, |
| 889 | unet, | 889 | unet, |
