-rw-r--r--  infer.py               |  2
-rw-r--r--  train_dreambooth.py    |  8
-rw-r--r--  train_lora.py          |  8
-rw-r--r--  train_ti.py            |  6
-rw-r--r--  training/functional.py | 21
5 files changed, 30 insertions(+), 15 deletions(-)
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -263,6 +263,8 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None)
             config,
             solver_p=create_scheduler(config, subscheduler),
         )
+    else:
+        raise ValueError(f"Unknown scheduler \"{scheduler}\"")
 
 
 def create_pipeline(model, dtype):
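The practical effect of the new else branch: create_scheduler now fails fast on an unrecognized scheduler name instead of falling through and implicitly returning None. A trimmed, hypothetical sketch of such a selection chain (the real infer.py dispatches over more scheduler names; only the final guard is taken from the diff):

    from diffusers import DEISMultistepScheduler, UniPCMultistepScheduler

    def create_scheduler_sketch(config, scheduler: str):
        # Illustration only: the branch names are assumptions,
        # the else/raise guard is what this commit adds.
        if scheduler == "unipc":
            return UniPCMultistepScheduler.from_config(config)
        elif scheduler == "deis":
            return DEISMultistepScheduler.from_config(config)
        else:
            raise ValueError(f"Unknown scheduler \"{scheduler}\"")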
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 8f0c6ea..e039df0 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -287,7 +287,7 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="lion",
+        default="adam",
         help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
@@ -459,7 +459,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, noise_scheduler="deis")
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -513,13 +513,15 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    else:
+    elif args.optimizer == 'lion':
         try:
             from lion_pytorch import Lion
         except ImportError:
             raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
 
         create_optimizer = partial(Lion, use_triton=True)
+    else:
+        raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
     trainer = partial(
         train,
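Taken together, the optimizer changes flip the default from Lion to Adam and turn an unrecognized --optimizer value into an explicit error. A minimal sketch of the resulting selection chain, assuming torch.optim.AdamW for "adam" and bitsandbytes for "adam8bit" (those two branches are abbreviated here; only the lion and else branches appear verbatim in the diff):

    from functools import partial
    import torch

    def create_optimizer_factory(optimizer_name: str):
        if optimizer_name == 'adam':
            return partial(torch.optim.AdamW)
        elif optimizer_name == 'adam8bit':
            import bitsandbytes as bnb  # assumed dependency for the 8-bit Adam branch
            return partial(bnb.optim.AdamW8bit)
        elif optimizer_name == 'lion':
            try:
                from lion_pytorch import Lion
            except ImportError:
                raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
            return partial(Lion, use_triton=True)
        else:
            # New behavior from this commit: unknown names no longer fall back to Lion.
            raise ValueError(f"Unknown --optimizer \"{optimizer_name}\"")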
diff --git a/train_lora.py b/train_lora.py
index 368c29b..db5330a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -247,7 +247,7 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="lion",
+        default="adam",
         help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
@@ -419,7 +419,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, noise_scheduler="deis")
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -488,13 +488,15 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    else:
+    elif args.optimizer == 'lion':
         try:
             from lion_pytorch import Lion
         except ImportError:
             raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
 
         create_optimizer = partial(Lion, use_triton=True)
+    else:
+        raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
     trainer = partial(
         train,
diff --git a/train_ti.py b/train_ti.py
index 507d710..12e3644 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -292,7 +292,7 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="lion",
+        default="adam",
        help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
@@ -586,13 +586,15 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    else:
+    elif args.optimizer == 'lion':
         try:
             from lion_pytorch import Lion
         except ImportError:
             raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
 
         create_optimizer = partial(Lion, use_triton=True)
+    else:
+        raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
     checkpoint_output_dir = output_dir/"checkpoints"
 
diff --git a/training/functional.py b/training/functional.py
index 4d0cf0e..85dd884 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler
+from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin
 
 from tqdm.auto import tqdm
 from PIL import Image
@@ -74,12 +74,19 @@ def make_grid(images, rows, cols):
     return grid
 
 
-def get_models(pretrained_model_name_or_path: str):
+def get_models(pretrained_model_name_or_path: str, noise_scheduler: str = "ddpm"):
+    if noise_scheduler == "deis":
+        noise_scheduler_cls = DEISMultistepScheduler
+    elif noise_scheduler == "ddpm":
+        noise_scheduler_cls = DDPMScheduler
+    else:
+        raise ValueError(f"noise_scheduler must be one of [\"ddpm\", \"deis\"], got {noise_scheduler}")
+
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DEISMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    noise_scheduler = noise_scheduler_cls.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
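get_models keeps DDPM as its default noise scheduler, while the training entry points changed above opt into DEIS explicitly. A usage sketch with a hypothetical model path:

    # Training scripts in this commit request the DEIS multistep scheduler explicitly:
    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        "path/to/pretrained-model",  # hypothetical; normally args.pretrained_model_name_or_path
        noise_scheduler="deis",
    )

    # Callers that omit the argument now get a DDPMScheduler instead:
    _, _, _, _, ddpm_scheduler, _, _ = get_models("path/to/pretrained-model")
    assert type(ddpm_scheduler).__name__ == "DDPMScheduler"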
@@ -94,7 +101,7 @@ def save_samples(
     text_encoder: CLIPTextModel,
     tokenizer: MultiCLIPTokenizer,
     vae: AutoencoderKL,
-    sample_scheduler: UniPCMultistepScheduler,
+    sample_scheduler: SchedulerMixin,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     output_dir: Path,
@@ -181,7 +188,7 @@ def generate_class_images(
     vae: AutoencoderKL,
     unet: UNet2DConditionModel,
     tokenizer: MultiCLIPTokenizer,
-    sample_scheduler: UniPCMultistepScheduler,
+    sample_scheduler: SchedulerMixin,
     train_dataset: VlpnDataset,
     sample_batch_size: int,
     sample_image_size: int,
@@ -252,7 +259,7 @@ def add_placeholder_tokens(
 
 def loss_step(
     vae: AutoencoderKL,
-    noise_scheduler: DEISMultistepScheduler,
+    noise_scheduler: SchedulerMixin,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
@@ -552,7 +559,7 @@ def train(
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
     vae: AutoencoderKL,
-    noise_scheduler: DEISMultistepScheduler,
+    noise_scheduler: SchedulerMixin,
     dtype: torch.dtype,
     seed: int,
     project: str,
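Widening the annotations from the concrete scheduler classes to diffusers' SchedulerMixin base lets loss_step, train, save_samples, and generate_class_images accept whichever scheduler get_models was asked for. A small illustration of the relaxed contract (names here are for demonstration only):

    from diffusers import DDPMScheduler, DEISMultistepScheduler, SchedulerMixin

    def scheduler_name(noise_scheduler: SchedulerMixin) -> str:
        # Both concrete schedulers satisfy the widened annotation.
        return type(noise_scheduler).__name__

    scheduler_name(DDPMScheduler())           # "DDPMScheduler"
    scheduler_name(DEISMultistepScheduler())  # "DEISMultistepScheduler"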