From 2ad46871e2ead985445da2848a4eb7072b6e48aa Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 14 Nov 2022 17:09:58 +0100 Subject: Update --- textual_inversion.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index 578c054..999161b 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -15,14 +15,13 @@ import torch.utils.checkpoint from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed -from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule from training.optimization import get_one_cycle_schedule @@ -134,7 +133,7 @@ def parse_args(): parser.add_argument( "--max_train_steps", type=int, - default=10000, + default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( @@ -251,6 +250,12 @@ def parse_args(): default=1, help="Number of samples to generate per batch", ) + parser.add_argument( + "--valid_set_size", + type=int, + default=None, + help="Number of images in the validation dataset." + ) parser.add_argument( "--train_batch_size", type=int, @@ -637,7 +642,7 @@ def main(): size=args.resolution, repeats=args.repeats, center_crop=args.center_crop, - valid_set_size=args.sample_batch_size*args.sample_batches, + valid_set_size=args.valid_set_size, num_workers=args.dataloader_num_workers, collate_fn=collate_fn ) -- cgit v1.2.3-54-g00ecf