diff options
author | Volpeon <git@volpeon.ink> | 2022-11-14 17:09:58 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-11-14 17:09:58 +0100 |
commit | 2ad46871e2ead985445da2848a4eb7072b6e48aa (patch) | |
tree | 3137923e2c00fe1d3cd37ddcc93c8a847b0c0762 /textual_inversion.py | |
parent | Update (diff) | |
download | textual-inversion-diff-2ad46871e2ead985445da2848a4eb7072b6e48aa.tar.gz textual-inversion-diff-2ad46871e2ead985445da2848a4eb7072b6e48aa.tar.bz2 textual-inversion-diff-2ad46871e2ead985445da2848a4eb7072b6e48aa.zip |
Update
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 13 |
1 file changed, 9 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 578c054..999161b 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -15,14 +15,13 @@ import torch.utils.checkpoint | |||
15 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel | 18 | from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel |
19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | 19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
20 | from PIL import Image | 20 | from PIL import Image |
21 | from tqdm.auto import tqdm | 21 | from tqdm.auto import tqdm |
22 | from transformers import CLIPTextModel, CLIPTokenizer | 22 | from transformers import CLIPTextModel, CLIPTokenizer |
23 | from slugify import slugify | 23 | from slugify import slugify |
24 | 24 | ||
25 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler | ||
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from data.csv import CSVDataModule | 26 | from data.csv import CSVDataModule |
28 | from training.optimization import get_one_cycle_schedule | 27 | from training.optimization import get_one_cycle_schedule |
@@ -134,7 +133,7 @@ def parse_args(): | |||
134 | parser.add_argument( | 133 | parser.add_argument( |
135 | "--max_train_steps", | 134 | "--max_train_steps", |
136 | type=int, | 135 | type=int, |
137 | default=10000, | 136 | default=None, |
138 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 137 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
139 | ) | 138 | ) |
140 | parser.add_argument( | 139 | parser.add_argument( |
@@ -252,6 +251,12 @@ def parse_args(): | |||
252 | help="Number of samples to generate per batch", | 251 | help="Number of samples to generate per batch", |
253 | ) | 252 | ) |
254 | parser.add_argument( | 253 | parser.add_argument( |
254 | "--valid_set_size", | ||
255 | type=int, | ||
256 | default=None, | ||
257 | help="Number of images in the validation dataset." | ||
258 | ) | ||
259 | parser.add_argument( | ||
255 | "--train_batch_size", | 260 | "--train_batch_size", |
256 | type=int, | 261 | type=int, |
257 | default=1, | 262 | default=1, |
@@ -637,7 +642,7 @@ def main(): | |||
637 | size=args.resolution, | 642 | size=args.resolution, |
638 | repeats=args.repeats, | 643 | repeats=args.repeats, |
639 | center_crop=args.center_crop, | 644 | center_crop=args.center_crop, |
640 | valid_set_size=args.sample_batch_size*args.sample_batches, | 645 | valid_set_size=args.valid_set_size, |
641 | num_workers=args.dataloader_num_workers, | 646 | num_workers=args.dataloader_num_workers, |
642 | collate_fn=collate_fn | 647 | collate_fn=collate_fn |
643 | ) | 648 | ) |