summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py13
1 file changed, 9 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index 578c054..999161b 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -15,14 +15,13 @@ import torch.utils.checkpoint
15from accelerate import Accelerator 15from accelerate import Accelerator
16from accelerate.logging import get_logger 16from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed 17from accelerate.utils import LoggerType, set_seed
18from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 18from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel
19from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup 19from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
20from PIL import Image 20from PIL import Image
21from tqdm.auto import tqdm 21from tqdm.auto import tqdm
22from transformers import CLIPTextModel, CLIPTokenizer 22from transformers import CLIPTextModel, CLIPTokenizer
23from slugify import slugify 23from slugify import slugify
24 24
25from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from data.csv import CSVDataModule 26from data.csv import CSVDataModule
28from training.optimization import get_one_cycle_schedule 27from training.optimization import get_one_cycle_schedule
@@ -134,7 +133,7 @@ def parse_args():
134 parser.add_argument( 133 parser.add_argument(
135 "--max_train_steps", 134 "--max_train_steps",
136 type=int, 135 type=int,
137 default=10000, 136 default=None,
138 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 137 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
139 ) 138 )
140 parser.add_argument( 139 parser.add_argument(
@@ -252,6 +251,12 @@ def parse_args():
252 help="Number of samples to generate per batch", 251 help="Number of samples to generate per batch",
253 ) 252 )
254 parser.add_argument( 253 parser.add_argument(
254 "--valid_set_size",
255 type=int,
256 default=None,
257 help="Number of images in the validation dataset."
258 )
259 parser.add_argument(
255 "--train_batch_size", 260 "--train_batch_size",
256 type=int, 261 type=int,
257 default=1, 262 default=1,
@@ -637,7 +642,7 @@ def main():
637 size=args.resolution, 642 size=args.resolution,
638 repeats=args.repeats, 643 repeats=args.repeats,
639 center_crop=args.center_crop, 644 center_crop=args.center_crop,
640 valid_set_size=args.sample_batch_size*args.sample_batches, 645 valid_set_size=args.valid_set_size,
641 num_workers=args.dataloader_num_workers, 646 num_workers=args.dataloader_num_workers,
642 collate_fn=collate_fn 647 collate_fn=collate_fn
643 ) 648 )