summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--infer.py2
-rw-r--r--train_dreambooth.py8
-rw-r--r--train_lora.py8
-rw-r--r--train_ti.py6
-rw-r--r--training/functional.py21
5 files changed, 30 insertions, 15 deletions
diff --git a/infer.py b/infer.py
index 8910e68..07dcd22 100644
--- a/infer.py
+++ b/infer.py
@@ -263,6 +263,8 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None)
263 config, 263 config,
264 solver_p=create_scheduler(config, subscheduler), 264 solver_p=create_scheduler(config, subscheduler),
265 ) 265 )
266 else:
267 raise ValueError(f"Unknown scheduler \"{scheduler}\"")
266 268
267 269
268def create_pipeline(model, dtype): 270def create_pipeline(model, dtype):
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 8f0c6ea..e039df0 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -287,7 +287,7 @@ def parse_args():
287 parser.add_argument( 287 parser.add_argument(
288 "--optimizer", 288 "--optimizer",
289 type=str, 289 type=str,
290 default="lion", 290 default="adam",
291 help='Optimizer to use ["adam", "adam8bit", "lion"]' 291 help='Optimizer to use ["adam", "adam8bit", "lion"]'
292 ) 292 )
293 parser.add_argument( 293 parser.add_argument(
@@ -459,7 +459,7 @@ def main():
459 save_args(output_dir, args) 459 save_args(output_dir, args)
460 460
461 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 461 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
462 args.pretrained_model_name_or_path) 462 args.pretrained_model_name_or_path, noise_scheduler="deis")
463 463
464 tokenizer.set_use_vector_shuffle(args.vector_shuffle) 464 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
465 tokenizer.set_dropout(args.vector_dropout) 465 tokenizer.set_dropout(args.vector_dropout)
@@ -513,13 +513,15 @@ def main():
513 eps=args.adam_epsilon, 513 eps=args.adam_epsilon,
514 amsgrad=args.adam_amsgrad, 514 amsgrad=args.adam_amsgrad,
515 ) 515 )
516 else: 516 elif args.optimizer == 'lion':
517 try: 517 try:
518 from lion_pytorch import Lion 518 from lion_pytorch import Lion
519 except ImportError: 519 except ImportError:
520 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") 520 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
521 521
522 create_optimizer = partial(Lion, use_triton=True) 522 create_optimizer = partial(Lion, use_triton=True)
523 else:
524 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
523 525
524 trainer = partial( 526 trainer = partial(
525 train, 527 train,
diff --git a/train_lora.py b/train_lora.py
index 368c29b..db5330a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -247,7 +247,7 @@ def parse_args():
247 parser.add_argument( 247 parser.add_argument(
248 "--optimizer", 248 "--optimizer",
249 type=str, 249 type=str,
250 default="lion", 250 default="adam",
251 help='Optimizer to use ["adam", "adam8bit", "lion"]' 251 help='Optimizer to use ["adam", "adam8bit", "lion"]'
252 ) 252 )
253 parser.add_argument( 253 parser.add_argument(
@@ -419,7 +419,7 @@ def main():
419 save_args(output_dir, args) 419 save_args(output_dir, args)
420 420
421 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 421 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
422 args.pretrained_model_name_or_path) 422 args.pretrained_model_name_or_path, noise_scheduler="deis")
423 423
424 vae.enable_slicing() 424 vae.enable_slicing()
425 vae.set_use_memory_efficient_attention_xformers(True) 425 vae.set_use_memory_efficient_attention_xformers(True)
@@ -488,13 +488,15 @@ def main():
488 eps=args.adam_epsilon, 488 eps=args.adam_epsilon,
489 amsgrad=args.adam_amsgrad, 489 amsgrad=args.adam_amsgrad,
490 ) 490 )
491 else: 491 elif args.optimizer == 'lion':
492 try: 492 try:
493 from lion_pytorch import Lion 493 from lion_pytorch import Lion
494 except ImportError: 494 except ImportError:
495 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") 495 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
496 496
497 create_optimizer = partial(Lion, use_triton=True) 497 create_optimizer = partial(Lion, use_triton=True)
498 else:
499 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
498 500
499 trainer = partial( 501 trainer = partial(
500 train, 502 train,
diff --git a/train_ti.py b/train_ti.py
index 507d710..12e3644 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -292,7 +292,7 @@ def parse_args():
292 parser.add_argument( 292 parser.add_argument(
293 "--optimizer", 293 "--optimizer",
294 type=str, 294 type=str,
295 default="lion", 295 default="adam",
296 help='Optimizer to use ["adam", "adam8bit", "lion"]' 296 help='Optimizer to use ["adam", "adam8bit", "lion"]'
297 ) 297 )
298 parser.add_argument( 298 parser.add_argument(
@@ -586,13 +586,15 @@ def main():
586 eps=args.adam_epsilon, 586 eps=args.adam_epsilon,
587 amsgrad=args.adam_amsgrad, 587 amsgrad=args.adam_amsgrad,
588 ) 588 )
589 else: 589 elif args.optimizer == 'lion':
590 try: 590 try:
591 from lion_pytorch import Lion 591 from lion_pytorch import Lion
592 except ImportError: 592 except ImportError:
593 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") 593 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
594 594
595 create_optimizer = partial(Lion, use_triton=True) 595 create_optimizer = partial(Lion, use_triton=True)
596 else:
597 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
596 598
597 checkpoint_output_dir = output_dir/"checkpoints" 599 checkpoint_output_dir = output_dir/"checkpoints"
598 600
diff --git a/training/functional.py b/training/functional.py
index 4d0cf0e..85dd884 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
12 12
13from accelerate import Accelerator 13from accelerate import Accelerator
14from transformers import CLIPTextModel 14from transformers import CLIPTextModel
15from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler 15from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin
16 16
17from tqdm.auto import tqdm 17from tqdm.auto import tqdm
18from PIL import Image 18from PIL import Image
@@ -74,12 +74,19 @@ def make_grid(images, rows, cols):
74 return grid 74 return grid
75 75
76 76
77def get_models(pretrained_model_name_or_path: str): 77def get_models(pretrained_model_name_or_path: str, noise_scheduler: str = "ddpm"):
78 if noise_scheduler == "deis":
79 noise_scheduler_cls = DEISMultistepScheduler
80 elif noise_scheduler == "ddpm":
81 noise_scheduler_cls = DDPMScheduler
82 else:
83 raise ValueError(f"noise_scheduler must be one of [\"ddpm\", \"deis\"], got {noise_scheduler}")
84
78 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 85 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
79 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 86 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
80 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 87 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
81 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') 88 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
82 noise_scheduler = DEISMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') 89 noise_scheduler = noise_scheduler_cls.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
83 sample_scheduler = UniPCMultistepScheduler.from_pretrained( 90 sample_scheduler = UniPCMultistepScheduler.from_pretrained(
84 pretrained_model_name_or_path, subfolder='scheduler') 91 pretrained_model_name_or_path, subfolder='scheduler')
85 92
@@ -94,7 +101,7 @@ def save_samples(
94 text_encoder: CLIPTextModel, 101 text_encoder: CLIPTextModel,
95 tokenizer: MultiCLIPTokenizer, 102 tokenizer: MultiCLIPTokenizer,
96 vae: AutoencoderKL, 103 vae: AutoencoderKL,
97 sample_scheduler: UniPCMultistepScheduler, 104 sample_scheduler: SchedulerMixin,
98 train_dataloader: DataLoader, 105 train_dataloader: DataLoader,
99 val_dataloader: Optional[DataLoader], 106 val_dataloader: Optional[DataLoader],
100 output_dir: Path, 107 output_dir: Path,
@@ -181,7 +188,7 @@ def generate_class_images(
181 vae: AutoencoderKL, 188 vae: AutoencoderKL,
182 unet: UNet2DConditionModel, 189 unet: UNet2DConditionModel,
183 tokenizer: MultiCLIPTokenizer, 190 tokenizer: MultiCLIPTokenizer,
184 sample_scheduler: UniPCMultistepScheduler, 191 sample_scheduler: SchedulerMixin,
185 train_dataset: VlpnDataset, 192 train_dataset: VlpnDataset,
186 sample_batch_size: int, 193 sample_batch_size: int,
187 sample_image_size: int, 194 sample_image_size: int,
@@ -252,7 +259,7 @@ def add_placeholder_tokens(
252 259
253def loss_step( 260def loss_step(
254 vae: AutoencoderKL, 261 vae: AutoencoderKL,
255 noise_scheduler: DEISMultistepScheduler, 262 noise_scheduler: SchedulerMixin,
256 unet: UNet2DConditionModel, 263 unet: UNet2DConditionModel,
257 text_encoder: CLIPTextModel, 264 text_encoder: CLIPTextModel,
258 with_prior_preservation: bool, 265 with_prior_preservation: bool,
@@ -552,7 +559,7 @@ def train(
552 unet: UNet2DConditionModel, 559 unet: UNet2DConditionModel,
553 text_encoder: CLIPTextModel, 560 text_encoder: CLIPTextModel,
554 vae: AutoencoderKL, 561 vae: AutoencoderKL,
555 noise_scheduler: DEISMultistepScheduler, 562 noise_scheduler: SchedulerMixin,
556 dtype: torch.dtype, 563 dtype: torch.dtype,
557 seed: int, 564 seed: int,
558 project: str, 565 project: str,