 environment.yaml          |  3 +--
 models/clip/embeddings.py |  8 --------
 train_dreambooth.py       | 11 ++---------
 train_lora.py             | 11 ++---------
 train_ti.py               | 21 ++++++++++-----------
 training/functional.py    |  7 +++++--
 training/strategy/ti.py   | 15 +++++++++++----
 7 files changed, 31 insertions(+), 45 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 57624a3..1e6ac60 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -13,16 +13,15 @@ dependencies:
   - python=3.10.8
   - pytorch=1.13.1=*cuda*
   - torchvision=0.14.1
+  - xformers=0.0.17.dev461
   - pip:
     - -e .
     - -e git+https://github.com/huggingface/diffusers#egg=diffusers
     - accelerate==0.16.0
     - bitsandbytes==0.37.0
-    - lion-pytorch==0.0.6
     - python-slugify>=6.1.2
     - safetensors==0.2.8
     - setuptools==65.6.3
     - test-tube>=0.7.5
     - transformers==4.26.1
     - triton==2.0.0a2
-    - xformers==0.0.17.dev451
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 6c41c33..734730e 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -98,14 +98,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return embeds
 
-    def normalize(self, target: float = 0.4, lambda_: float = 1.0):
-        if lambda_ == 0:
-            return
-
-        w = self.temp_token_embedding.weight
-        norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
-        w[self.temp_token_ids].add_((w[self.temp_token_ids] / norm.clamp_min(1e-12)) * lambda_ * (target - norm))
-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
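For reference, the deleted normalize() nudged the L2 norm of each managed token embedding toward a fixed target; the same update reappears in training/strategy/ti.py further down in this diff. A minimal standalone sketch of the update rule (function and tensor names are illustrative, not from the repo):

import torch

def decay_toward_target_norm(w: torch.Tensor, target: float = 0.4, lambda_: float = 1.0) -> torch.Tensor:
    # w <- w + (w / ||w||) * lambda_ * (target - ||w||), applied row-wise
    norm = w.norm(dim=-1, keepdim=True)
    return w + (w / norm.clamp_min(1e-12)) * lambda_ * (target - norm)

new_rows = decay_toward_target_norm(torch.randn(8, 768))  # rows drift toward norm 0.4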
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e039df0..431ff3d 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -288,7 +288,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adam",
-        help='Optimizer to use ["adam", "adam8bit", "lion"]'
+        help='Optimizer to use ["adam", "adam8bit"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -459,7 +459,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, noise_scheduler="deis")
+        args.pretrained_model_name_or_path)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -513,13 +513,6 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'lion':
-        try:
-            from lion_pytorch import Lion
-        except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
-
-        create_optimizer = partial(Lion, use_triton=True)
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
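With the Lion branch removed, --optimizer now only selects between AdamW and its 8-bit variant. A rough sketch of the remaining factory pattern, using functools.partial as the surrounding code does (the adam8bit branch and the --adam_beta2/--adam_weight_decay names are assumed, since they fall outside these hunks):

from functools import partial
import torch

def make_optimizer_factory(args):
    if args.optimizer == 'adam':
        return partial(
            torch.optim.AdamW,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            amsgrad=args.adam_amsgrad,
        )
    elif args.optimizer == 'adam8bit':
        import bitsandbytes as bnb  # pinned in environment.yaml
        return partial(
            bnb.optim.AdamW8bit,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )
    raise ValueError(f'Unknown --optimizer "{args.optimizer}"')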
diff --git a/train_lora.py b/train_lora.py
index db5330a..a06591d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -248,7 +248,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adam",
-        help='Optimizer to use ["adam", "adam8bit", "lion"]'
+        help='Optimizer to use ["adam", "adam8bit"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -419,7 +419,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, noise_scheduler="deis")
+        args.pretrained_model_name_or_path)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -488,13 +488,6 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'lion':
-        try:
-            from lion_pytorch import Lion
-        except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
-
-        create_optimizer = partial(Lion, use_triton=True)
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
diff --git a/train_ti.py b/train_ti.py
index 12e3644..6dc07dd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -86,7 +86,7 @@ def parse_args():
         help="Number of vectors per embedding."
     )
     parser.add_argument(
-        "--simultaneous",
+        "--sequential",
         action="store_true",
     )
     parser.add_argument(
@@ -293,7 +293,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adam",
-        help='Optimizer to use ["adam", "adam8bit", "lion"]'
+        help='Optimizer to use ["adam", "adam8bit"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -343,6 +343,11 @@ def parse_args():
         help="How often to save a checkpoint and sample image (in epochs)",
     )
     parser.add_argument(
+        "--no_milestone_checkpoints",
+        action='store_true',
+        help="Don't save additional checkpoints when a new best accuracy is reached",
+    )
+    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=1,
@@ -480,7 +485,7 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
-    if not args.simultaneous:
+    if args.sequential:
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
@@ -586,13 +591,6 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'lion':
-        try:
-            from lion_pytorch import Lion
-        except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
-
-        create_optimizer = partial(Lion, use_triton=True)
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
@@ -615,6 +613,7 @@ def main():
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
+        milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         # --
         tokenizer=tokenizer,
@@ -715,7 +714,7 @@ def main():
 
     plot_metrics(metrics, metrics_output_file)
 
-    if args.simultaneous:
+    if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
         for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
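Two behavioral notes on train_ti.py: the old --simultaneous switch is inverted into --sequential, so training all placeholder tokens in a single run is now the default, and the new --no_milestone_checkpoints flag is negated before being handed to train(). A tiny self-contained sketch of that wiring (argparse only, no repo code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--sequential", action="store_true")
parser.add_argument("--no_milestone_checkpoints", action="store_true")
args = parser.parse_args([])  # no flags given: the defaults

milestone_checkpoints = not args.no_milestone_checkpoints  # value passed to train(...)
assert milestone_checkpoints and not args.sequential  # simultaneous run, milestones on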
diff --git a/training/functional.py b/training/functional.py
index 85dd884..739d055 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -362,6 +362,7 @@ def train_loop(
     loss_step: LossCallable,
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
+    milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     callbacks: TrainingCallbacks = TrainingCallbacks(),
@@ -514,7 +515,7 @@ def train_loop(
                 accelerator.log(logs, step=global_step)
 
                 if accelerator.is_main_process:
-                    if avg_acc_val.avg.item() > best_acc_val:
+                    if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints:
                         local_progress_bar.clear()
                         global_progress_bar.clear()
 
@@ -527,7 +528,7 @@ def train_loop(
                         accs.append(avg_acc_val.avg.item())
             else:
                 if accelerator.is_main_process:
-                    if avg_acc.avg.item() > best_acc:
+                    if avg_acc.avg.item() > best_acc and milestone_checkpoints:
                         local_progress_bar.clear()
                         global_progress_bar.clear()
 
@@ -572,6 +573,7 @@ def train(
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
+    milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
@@ -626,6 +628,7 @@ def train(
         loss_step=loss_step_,
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
+        milestone_checkpoints=milestone_checkpoints,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         callbacks=callbacks,
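train_loop and train now accept milestone_checkpoints, which gates the extra checkpoints written when validation (or training) accuracy reaches a new best, while the periodic checkpoint_frequency saves are unaffected. A simplified sketch of the gate; best_acc bookkeeping and on_checkpoint are stand-ins for the real loop state and callback:

def maybe_save_milestone(avg_acc: float, best_acc: float,
                         milestone_checkpoints: bool, on_checkpoint) -> float:
    # Save (and advance the best-accuracy watermark) only when gating allows it.
    if milestone_checkpoints and avg_acc > best_acc:
        on_checkpoint()
        return avg_acc
    return best_acc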
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 66d3129..09beec4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -116,10 +116,17 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            text_encoder.text_model.embeddings.normalize(
-                emb_decay_target,
-                min(1.0, emb_decay * lr)
-            )
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+
+                mask = torch.zeros(w.size(0), dtype=torch.bool)
+                mask[text_encoder.text_model.embeddings.temp_token_ids] = True
+                mask[torch.all(w.grad == 0, dim=1)] = False
+
+                norm = w[mask, :].norm(dim=-1, keepdim=True)
+                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
 
     def on_after_optimize(lr: float):
         if ema_embeddings is not None:
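The embedding decay now lives in the textual-inversion strategy callback and additionally skips rows whose gradient is all zero, i.e. tokens that were not used in the current step. A standalone sketch of the same update under assumed shapes (decay_trained_embeddings is a hypothetical helper; w is the full embedding matrix, token_ids the trained rows). Note that advanced indexing returns a copy, so the sketch assigns the result back rather than calling add_() on the indexed tensor:

import torch

@torch.no_grad()
def decay_trained_embeddings(w: torch.Tensor, token_ids: torch.Tensor,
                             target: float, lambda_: float) -> None:
    # Pull the norm of each trained row toward `target` by a factor `lambda_`,
    # skipping rows that received no gradient this step.
    if lambda_ == 0:
        return

    mask = torch.zeros(w.size(0), dtype=torch.bool, device=w.device)
    mask[token_ids] = True
    if w.grad is not None:
        mask[torch.all(w.grad == 0, dim=1)] = False

    norm = w[mask].norm(dim=-1, keepdim=True)
    w[mask] += (w[mask] / norm.clamp_min(1e-12)) * lambda_ * (target - norm)

# Toy example: only token 3 received a gradient, so only its row is decayed.
emb = torch.nn.Embedding(10, 4)
emb.weight.grad = torch.zeros_like(emb.weight)
emb.weight.grad[3] = 1.0
decay_trained_embeddings(emb.weight, torch.tensor([3, 7]), target=0.4, lambda_=0.01)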
