diff options
author | Volpeon <git@volpeon.ink> | 2023-01-14 21:53:07 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-14 21:53:07 +0100 |
commit | 83808fe00ac891ad2f625388d144c318b2cb5bfe (patch) | |
tree | b7ca19d27f90be6f02b14f4a39c62fc7250041a2 /train_dreambooth.py | |
parent | TI: Prepare UNet with Accelerate as well (diff) | |
download | textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.tar.gz textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.tar.bz2 textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.zip |
WIP: Modularization ("free(): invalid pointer" my ass)
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 3 |
1 files changed, 1 insertions, 2 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 53776ba..71bad7e 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -20,10 +20,9 @@ from slugify import slugify | |||
20 | from util import load_config, load_embeddings_from_dir | 20 | from util import load_config, load_embeddings_from_dir |
21 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 21 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
22 | from data.csv import VlpnDataModule, VlpnDataItem | 22 | from data.csv import VlpnDataModule, VlpnDataItem |
23 | from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models | ||
24 | from training.optimization import get_scheduler | 23 | from training.optimization import get_scheduler |
25 | from training.lr import LRFinder | 24 | from training.lr import LRFinder |
26 | from training.util import CheckpointerBase, EMAModel, save_args | 25 | from training.util import CheckpointerBase, EMAModel, save_args, generate_class_images, add_placeholder_tokens, get_models |
27 | from models.clip.tokenizer import MultiCLIPTokenizer | 26 | from models.clip.tokenizer import MultiCLIPTokenizer |
28 | 27 | ||
29 | logger = get_logger(__name__) | 28 | logger = get_logger(__name__) |