diff options
| author | Volpeon <git@volpeon.ink> | 2023-02-18 13:00:13 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-18 13:00:13 +0100 |
| commit | 2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2 (patch) | |
| tree | e08741c9df3b30a05ade472da45d7410bbf972ae | |
| parent | Added Lion optimizer (diff) | |
| download | textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.gz textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.bz2 textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.zip | |
Update
| -rw-r--r-- | infer.py | 2 | ||||
| -rw-r--r-- | train_dreambooth.py | 8 | ||||
| -rw-r--r-- | train_lora.py | 8 | ||||
| -rw-r--r-- | train_ti.py | 6 | ||||
| -rw-r--r-- | training/functional.py | 21 |
5 files changed, 30 insertions, 15 deletions
diff --git a/infer.py b/infer.py --- a/infer.py +++ b/infer.py | |||
| @@ -263,6 +263,8 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None) | |||
| 263 | config, | 263 | config, |
| 264 | solver_p=create_scheduler(config, subscheduler), | 264 | solver_p=create_scheduler(config, subscheduler), |
| 265 | ) | 265 | ) |
| 266 | else: | ||
| 267 | raise ValueError(f"Unknown scheduler \"{scheduler}\"") | ||
| 266 | 268 | ||
| 267 | 269 | ||
| 268 | def create_pipeline(model, dtype): | 270 | def create_pipeline(model, dtype): |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 8f0c6ea..e039df0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -287,7 +287,7 @@ def parse_args(): | |||
| 287 | parser.add_argument( | 287 | parser.add_argument( |
| 288 | "--optimizer", | 288 | "--optimizer", |
| 289 | type=str, | 289 | type=str, |
| 290 | default="lion", | 290 | default="adam", |
| 291 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | 291 | help='Optimizer to use ["adam", "adam8bit", "lion"]' |
| 292 | ) | 292 | ) |
| 293 | parser.add_argument( | 293 | parser.add_argument( |
| @@ -459,7 +459,7 @@ def main(): | |||
| 459 | save_args(output_dir, args) | 459 | save_args(output_dir, args) |
| 460 | 460 | ||
| 461 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 461 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
| 462 | args.pretrained_model_name_or_path) | 462 | args.pretrained_model_name_or_path, noise_scheduler="deis") |
| 463 | 463 | ||
| 464 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 464 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
| 465 | tokenizer.set_dropout(args.vector_dropout) | 465 | tokenizer.set_dropout(args.vector_dropout) |
| @@ -513,13 +513,15 @@ def main(): | |||
| 513 | eps=args.adam_epsilon, | 513 | eps=args.adam_epsilon, |
| 514 | amsgrad=args.adam_amsgrad, | 514 | amsgrad=args.adam_amsgrad, |
| 515 | ) | 515 | ) |
| 516 | else: | 516 | elif args.optimizer == 'lion': |
| 517 | try: | 517 | try: |
| 518 | from lion_pytorch import Lion | 518 | from lion_pytorch import Lion |
| 519 | except ImportError: | 519 | except ImportError: |
| 520 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | 520 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") |
| 521 | 521 | ||
| 522 | create_optimizer = partial(Lion, use_triton=True) | 522 | create_optimizer = partial(Lion, use_triton=True) |
| 523 | else: | ||
| 524 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | ||
| 523 | 525 | ||
| 524 | trainer = partial( | 526 | trainer = partial( |
| 525 | train, | 527 | train, |
diff --git a/train_lora.py b/train_lora.py index 368c29b..db5330a 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -247,7 +247,7 @@ def parse_args(): | |||
| 247 | parser.add_argument( | 247 | parser.add_argument( |
| 248 | "--optimizer", | 248 | "--optimizer", |
| 249 | type=str, | 249 | type=str, |
| 250 | default="lion", | 250 | default="adam", |
| 251 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | 251 | help='Optimizer to use ["adam", "adam8bit", "lion"]' |
| 252 | ) | 252 | ) |
| 253 | parser.add_argument( | 253 | parser.add_argument( |
| @@ -419,7 +419,7 @@ def main(): | |||
| 419 | save_args(output_dir, args) | 419 | save_args(output_dir, args) |
| 420 | 420 | ||
| 421 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 421 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
| 422 | args.pretrained_model_name_or_path) | 422 | args.pretrained_model_name_or_path, noise_scheduler="deis") |
| 423 | 423 | ||
| 424 | vae.enable_slicing() | 424 | vae.enable_slicing() |
| 425 | vae.set_use_memory_efficient_attention_xformers(True) | 425 | vae.set_use_memory_efficient_attention_xformers(True) |
| @@ -488,13 +488,15 @@ def main(): | |||
| 488 | eps=args.adam_epsilon, | 488 | eps=args.adam_epsilon, |
| 489 | amsgrad=args.adam_amsgrad, | 489 | amsgrad=args.adam_amsgrad, |
| 490 | ) | 490 | ) |
| 491 | else: | 491 | elif args.optimizer == 'lion': |
| 492 | try: | 492 | try: |
| 493 | from lion_pytorch import Lion | 493 | from lion_pytorch import Lion |
| 494 | except ImportError: | 494 | except ImportError: |
| 495 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | 495 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") |
| 496 | 496 | ||
| 497 | create_optimizer = partial(Lion, use_triton=True) | 497 | create_optimizer = partial(Lion, use_triton=True) |
| 498 | else: | ||
| 499 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | ||
| 498 | 500 | ||
| 499 | trainer = partial( | 501 | trainer = partial( |
| 500 | train, | 502 | train, |
diff --git a/train_ti.py b/train_ti.py index 507d710..12e3644 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -292,7 +292,7 @@ def parse_args(): | |||
| 292 | parser.add_argument( | 292 | parser.add_argument( |
| 293 | "--optimizer", | 293 | "--optimizer", |
| 294 | type=str, | 294 | type=str, |
| 295 | default="lion", | 295 | default="adam", |
| 296 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | 296 | help='Optimizer to use ["adam", "adam8bit", "lion"]' |
| 297 | ) | 297 | ) |
| 298 | parser.add_argument( | 298 | parser.add_argument( |
| @@ -586,13 +586,15 @@ def main(): | |||
| 586 | eps=args.adam_epsilon, | 586 | eps=args.adam_epsilon, |
| 587 | amsgrad=args.adam_amsgrad, | 587 | amsgrad=args.adam_amsgrad, |
| 588 | ) | 588 | ) |
| 589 | else: | 589 | elif args.optimizer == 'lion': |
| 590 | try: | 590 | try: |
| 591 | from lion_pytorch import Lion | 591 | from lion_pytorch import Lion |
| 592 | except ImportError: | 592 | except ImportError: |
| 593 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | 593 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") |
| 594 | 594 | ||
| 595 | create_optimizer = partial(Lion, use_triton=True) | 595 | create_optimizer = partial(Lion, use_triton=True) |
| 596 | else: | ||
| 597 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | ||
| 596 | 598 | ||
| 597 | checkpoint_output_dir = output_dir/"checkpoints" | 599 | checkpoint_output_dir = output_dir/"checkpoints" |
| 598 | 600 | ||
diff --git a/training/functional.py b/training/functional.py index 4d0cf0e..85dd884 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -12,7 +12,7 @@ from torch.utils.data import DataLoader | |||
| 12 | 12 | ||
| 13 | from accelerate import Accelerator | 13 | from accelerate import Accelerator |
| 14 | from transformers import CLIPTextModel | 14 | from transformers import CLIPTextModel |
| 15 | from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler | 15 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin |
| 16 | 16 | ||
| 17 | from tqdm.auto import tqdm | 17 | from tqdm.auto import tqdm |
| 18 | from PIL import Image | 18 | from PIL import Image |
| @@ -74,12 +74,19 @@ def make_grid(images, rows, cols): | |||
| 74 | return grid | 74 | return grid |
| 75 | 75 | ||
| 76 | 76 | ||
| 77 | def get_models(pretrained_model_name_or_path: str): | 77 | def get_models(pretrained_model_name_or_path: str, noise_scheduler: str = "ddpm"): |
| 78 | if noise_scheduler == "deis": | ||
| 79 | noise_scheduler_cls = DEISMultistepScheduler | ||
| 80 | elif noise_scheduler == "ddpm": | ||
| 81 | noise_scheduler_cls = DDPMScheduler | ||
| 82 | else: | ||
| 83 | raise ValueError(f"noise_scheduler must be one of [\"ddpm\", \"deis\"], got {noise_scheduler}") | ||
| 84 | |||
| 78 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 85 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
| 79 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 86 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
| 80 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 87 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
| 81 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') | 88 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') |
| 82 | noise_scheduler = DEISMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | 89 | noise_scheduler = noise_scheduler_cls.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') |
| 83 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 90 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
| 84 | pretrained_model_name_or_path, subfolder='scheduler') | 91 | pretrained_model_name_or_path, subfolder='scheduler') |
| 85 | 92 | ||
| @@ -94,7 +101,7 @@ def save_samples( | |||
| 94 | text_encoder: CLIPTextModel, | 101 | text_encoder: CLIPTextModel, |
| 95 | tokenizer: MultiCLIPTokenizer, | 102 | tokenizer: MultiCLIPTokenizer, |
| 96 | vae: AutoencoderKL, | 103 | vae: AutoencoderKL, |
| 97 | sample_scheduler: UniPCMultistepScheduler, | 104 | sample_scheduler: SchedulerMixin, |
| 98 | train_dataloader: DataLoader, | 105 | train_dataloader: DataLoader, |
| 99 | val_dataloader: Optional[DataLoader], | 106 | val_dataloader: Optional[DataLoader], |
| 100 | output_dir: Path, | 107 | output_dir: Path, |
| @@ -181,7 +188,7 @@ def generate_class_images( | |||
| 181 | vae: AutoencoderKL, | 188 | vae: AutoencoderKL, |
| 182 | unet: UNet2DConditionModel, | 189 | unet: UNet2DConditionModel, |
| 183 | tokenizer: MultiCLIPTokenizer, | 190 | tokenizer: MultiCLIPTokenizer, |
| 184 | sample_scheduler: UniPCMultistepScheduler, | 191 | sample_scheduler: SchedulerMixin, |
| 185 | train_dataset: VlpnDataset, | 192 | train_dataset: VlpnDataset, |
| 186 | sample_batch_size: int, | 193 | sample_batch_size: int, |
| 187 | sample_image_size: int, | 194 | sample_image_size: int, |
| @@ -252,7 +259,7 @@ def add_placeholder_tokens( | |||
| 252 | 259 | ||
| 253 | def loss_step( | 260 | def loss_step( |
| 254 | vae: AutoencoderKL, | 261 | vae: AutoencoderKL, |
| 255 | noise_scheduler: DEISMultistepScheduler, | 262 | noise_scheduler: SchedulerMixin, |
| 256 | unet: UNet2DConditionModel, | 263 | unet: UNet2DConditionModel, |
| 257 | text_encoder: CLIPTextModel, | 264 | text_encoder: CLIPTextModel, |
| 258 | with_prior_preservation: bool, | 265 | with_prior_preservation: bool, |
| @@ -552,7 +559,7 @@ def train( | |||
| 552 | unet: UNet2DConditionModel, | 559 | unet: UNet2DConditionModel, |
| 553 | text_encoder: CLIPTextModel, | 560 | text_encoder: CLIPTextModel, |
| 554 | vae: AutoencoderKL, | 561 | vae: AutoencoderKL, |
| 555 | noise_scheduler: DEISMultistepScheduler, | 562 | noise_scheduler: SchedulerMixin, |
| 556 | dtype: torch.dtype, | 563 | dtype: torch.dtype, |
| 557 | seed: int, | 564 | seed: int, |
| 558 | project: str, | 565 | project: str, |
