From 16b92605a59d59c65789c89b54bb97da51908056 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 21 Feb 2023 09:09:50 +0100
Subject: Embedding normalization: Ignore tensors with grad = 0

---
 environment.yaml          |  3 +--
 models/clip/embeddings.py |  8 --------
 train_dreambooth.py       | 11 ++---------
 train_lora.py             | 11 ++---------
 train_ti.py               | 21 ++++++++++-----------
 training/functional.py    |  7 +++++--
 training/strategy/ti.py   | 15 +++++++++++----
 7 files changed, 31 insertions(+), 45 deletions(-)

diff --git a/environment.yaml b/environment.yaml
index 57624a3..1e6ac60 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -13,16 +13,15 @@ dependencies:
   - python=3.10.8
   - pytorch=1.13.1=*cuda*
   - torchvision=0.14.1
+  - xformers=0.0.17.dev461
   - pip:
     - -e .
     - -e git+https://github.com/huggingface/diffusers#egg=diffusers
     - accelerate==0.16.0
     - bitsandbytes==0.37.0
-    - lion-pytorch==0.0.6
     - python-slugify>=6.1.2
     - safetensors==0.2.8
     - setuptools==65.6.3
     - test-tube>=0.7.5
     - transformers==4.26.1
     - triton==2.0.0a2
-    - xformers==0.0.17.dev451
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 6c41c33..734730e 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -98,14 +98,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return embeds
 
-    def normalize(self, target: float = 0.4, lambda_: float = 1.0):
-        if lambda_ == 0:
-            return
-
-        w = self.temp_token_embedding.weight
-        norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
-        w[self.temp_token_ids].add_((w[self.temp_token_ids] / norm.clamp_min(1e-12)) * lambda_ * (target - norm))
-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e039df0..431ff3d 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -288,7 +288,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adam",
-        help='Optimizer to use ["adam", "adam8bit", "lion"]'
+        help='Optimizer to use ["adam", "adam8bit"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -459,7 +459,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, noise_scheduler="deis")
+        args.pretrained_model_name_or_path)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -513,13 +513,6 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'lion':
-        try:
-            from lion_pytorch import Lion
-        except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
-
-        create_optimizer = partial(Lion, use_triton=True)
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
diff --git a/train_lora.py b/train_lora.py
index db5330a..a06591d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -248,7 +248,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adam",
-        help='Optimizer to use ["adam", "adam8bit", "lion"]'
+        help='Optimizer to use ["adam", "adam8bit"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -419,7 +419,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, noise_scheduler="deis")
+        args.pretrained_model_name_or_path)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -488,13 +488,6 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'lion':
-        try:
-            from lion_pytorch import Lion
-        except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
-
-        create_optimizer = partial(Lion, use_triton=True)
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
diff --git a/train_ti.py b/train_ti.py
index 12e3644..6dc07dd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -86,7 +86,7 @@ def parse_args():
         help="Number of vectors per embedding."
     )
     parser.add_argument(
-        "--simultaneous",
+        "--sequential",
         action="store_true",
     )
     parser.add_argument(
@@ -293,7 +293,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adam",
-        help='Optimizer to use ["adam", "adam8bit", "lion"]'
+        help='Optimizer to use ["adam", "adam8bit"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -342,6 +342,11 @@ def parse_args():
         default=5,
         help="How often to save a checkpoint and sample image (in epochs)",
     )
+    parser.add_argument(
+        "--no_milestone_checkpoints",
+        action='store_true',
+        help="If checkpoints are saved on maximum accuracy",
+    )
     parser.add_argument(
         "--sample_frequency",
         type=int,
@@ -480,7 +485,7 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
-    if not args.simultaneous:
+    if args.sequential:
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
@@ -586,13 +591,6 @@ def main():
            eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'lion':
-        try:
-            from lion_pytorch import Lion
-        except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
-
-        create_optimizer = partial(Lion, use_triton=True)
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
@@ -615,6 +613,7 @@ def main():
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
+        milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         # --
         tokenizer=tokenizer,
@@ -715,7 +714,7 @@ def main():
 
     plot_metrics(metrics, metrics_output_file)
 
-    if args.simultaneous:
+    if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
         for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
diff --git a/training/functional.py b/training/functional.py
index 85dd884..739d055 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -362,6 +362,7 @@ def train_loop(
     loss_step: LossCallable,
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
+    milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     callbacks: TrainingCallbacks = TrainingCallbacks(),
@@ -514,7 +515,7 @@ def train_loop(
             accelerator.log(logs, step=global_step)
 
             if accelerator.is_main_process:
-                if avg_acc_val.avg.item() > best_acc_val:
+                if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
@@ -527,7 +528,7 @@ def train_loop(
                 accs.append(avg_acc_val.avg.item())
         else:
             if accelerator.is_main_process:
-                if avg_acc.avg.item() > best_acc:
+                if avg_acc.avg.item() > best_acc and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
@@ -572,6 +573,7 @@ def train(
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
+    milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
@@ -626,6 +628,7 @@ def train(
         loss_step=loss_step_,
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
+        milestone_checkpoints=milestone_checkpoints,
        global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         callbacks=callbacks,
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 66d3129..09beec4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -116,10 +116,17 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            text_encoder.text_model.embeddings.normalize(
-                emb_decay_target,
-                min(1.0, emb_decay * lr)
-            )
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+
+                mask = torch.zeros(w.size(0), dtype=torch.bool)
+                mask[text_encoder.text_model.embeddings.temp_token_ids] = True
+                mask[torch.all(w.grad == 0, dim=1)] = False
+
+                norm = w[mask, :].norm(dim=-1, keepdim=True)
+                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
 
     def on_after_optimize(lr: float):
         if ema_embeddings is not None:
-- 
cgit v1.2.3-54-g00ecf
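
Note (not part of the commit): a minimal, standalone sketch of the decay step the new on_before_optimize hook performs, so the effect of skipping zero-gradient rows is easy to see in isolation. All names and values below are illustrative; only emb_decay_target = 0.4 is taken from the patch (the default of the removed normalize() method).

    # Standalone sketch, illustrative values only (not repo code).
    import torch

    torch.manual_seed(0)

    num_embeddings, dim = 8, 4
    w = torch.randn(num_embeddings, dim)

    # Pretend a backward pass happened: only rows 2 and 5 received a gradient.
    grad = torch.zeros_like(w)
    grad[2] = torch.randn(dim)
    grad[5] = torch.randn(dim)

    temp_token_ids = torch.tensor([2, 3, 5])  # rows being trained
    emb_decay_target = 0.4                    # target norm (default of the removed normalize())
    lambda_ = 0.01                            # stands in for emb_decay * lr on this step

    if lambda_ != 0:
        # Start from the trained rows, then drop any row whose gradient is all
        # zeros: that token took no part in this step, so its norm is left alone.
        mask = torch.zeros(w.size(0), dtype=torch.bool)
        mask[temp_token_ids] = True
        mask[torch.all(grad == 0, dim=1)] = False

        # Nudge the norm of each remaining row toward emb_decay_target.
        norm = w[mask, :].norm(dim=-1, keepdim=True)
        w[mask] += (w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)

    print(mask)                  # True only for rows 2 and 5; row 3 had a zero gradient
    print(w[mask].norm(dim=-1))  # norms moved slightly toward 0.4

Row 3 is a trained placeholder token, but since it received an all-zero gradient on this step it is dropped from the mask and its norm is untouched, which is the behaviour the commit subject describes.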
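
Also hedged and not from the repo: a small sketch of what the new milestone_checkpoints switch gates in train_loop. The "save on new maximum accuracy" branch now only fires when the flag is set, and --no_milestone_checkpoints in train_ti.py turns it off; the helper below is a made-up illustration of that gating, not the repo's API.

    # Illustrative helper mirroring the condition added in train_loop (hypothetical names).
    def maybe_save_milestone(avg_acc: float, best_acc: float,
                             milestone_checkpoints: bool,
                             save_checkpoint) -> float:
        """Save a 'milestone' checkpoint only when enabled and accuracy improved.

        Returns the (possibly updated) best accuracy for the caller to track.
        """
        if avg_acc > best_acc and milestone_checkpoints:
            save_checkpoint("milestone")
            return avg_acc
        return best_acc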