diff options
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 578c054..999161b 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -15,14 +15,13 @@ import torch.utils.checkpoint | |||
| 15 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
| 16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
| 17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
| 18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel | 18 | from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel |
| 19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | 19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
| 20 | from PIL import Image | 20 | from PIL import Image |
| 21 | from tqdm.auto import tqdm | 21 | from tqdm.auto import tqdm |
| 22 | from transformers import CLIPTextModel, CLIPTokenizer | 22 | from transformers import CLIPTextModel, CLIPTokenizer |
| 23 | from slugify import slugify | 23 | from slugify import slugify |
| 24 | 24 | ||
| 25 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler | ||
| 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 27 | from data.csv import CSVDataModule | 26 | from data.csv import CSVDataModule |
| 28 | from training.optimization import get_one_cycle_schedule | 27 | from training.optimization import get_one_cycle_schedule |
| @@ -134,7 +133,7 @@ def parse_args(): | |||
| 134 | parser.add_argument( | 133 | parser.add_argument( |
| 135 | "--max_train_steps", | 134 | "--max_train_steps", |
| 136 | type=int, | 135 | type=int, |
| 137 | default=10000, | 136 | default=None, |
| 138 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 137 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
| 139 | ) | 138 | ) |
| 140 | parser.add_argument( | 139 | parser.add_argument( |
| @@ -252,6 +251,12 @@ def parse_args(): | |||
| 252 | help="Number of samples to generate per batch", | 251 | help="Number of samples to generate per batch", |
| 253 | ) | 252 | ) |
| 254 | parser.add_argument( | 253 | parser.add_argument( |
| 254 | "--valid_set_size", | ||
| 255 | type=int, | ||
| 256 | default=None, | ||
| 257 | help="Number of images in the validation dataset." | ||
| 258 | ) | ||
| 259 | parser.add_argument( | ||
| 255 | "--train_batch_size", | 260 | "--train_batch_size", |
| 256 | type=int, | 261 | type=int, |
| 257 | default=1, | 262 | default=1, |
| @@ -637,7 +642,7 @@ def main(): | |||
| 637 | size=args.resolution, | 642 | size=args.resolution, |
| 638 | repeats=args.repeats, | 643 | repeats=args.repeats, |
| 639 | center_crop=args.center_crop, | 644 | center_crop=args.center_crop, |
| 640 | valid_set_size=args.sample_batch_size*args.sample_batches, | 645 | valid_set_size=args.valid_set_size, |
| 641 | num_workers=args.dataloader_num_workers, | 646 | num_workers=args.dataloader_num_workers, |
| 642 | collate_fn=collate_fn | 647 | collate_fn=collate_fn |
| 643 | ) | 648 | ) |
