From 950f1f6bcbb1a767170cea590b828d8e3cdae882 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 23 Jun 2023 06:48:38 +0200 Subject: Update --- train_dreambooth.py | 32 ++++++++++++++------------------ train_lora.py | 18 ------------------ train_ti.py | 20 +------------------- training/strategy/dreambooth.py | 26 +++++++++++++++++++++++++- 4 files changed, 40 insertions(+), 56 deletions(-) diff --git a/train_dreambooth.py b/train_dreambooth.py index d284346..c8f03ea 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -144,12 +144,6 @@ def parse_args(): default=[], help="Tokens to create an alias for.", ) - parser.add_argument( - "--inverted_initializer_tokens", - type=str, - nargs="*", - help="A token to use as initializer word.", - ) parser.add_argument( "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." ) @@ -498,6 +492,15 @@ def parse_args(): default=0, help="Embedding dropout probability.", ) + parser.add_argument( + "--use_emb_decay", action="store_true", help="Whether to use embedding decay." + ) + parser.add_argument( + "--emb_decay_target", default=0.4, type=float, help="Embedding decay target." + ) + parser.add_argument( + "--emb_decay", default=1e2, type=float, help="Embedding decay factor." + ) parser.add_argument( "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." ) @@ -554,18 +557,6 @@ def parse_args(): "--placeholder_tokens and --initializer_tokens must have the same number of items" ) - if isinstance(args.inverted_initializer_tokens, str): - args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( - args.placeholder_tokens - ) - - if ( - isinstance(args.inverted_initializer_tokens, list) - and len(args.inverted_initializer_tokens) != 0 - ): - args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] - args.initializer_tokens += args.inverted_initializer_tokens - if isinstance(args.num_vectors, int): args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) @@ -875,6 +866,11 @@ def main(): sample_num_batches=args.sample_batches, sample_num_steps=args.sample_steps, sample_image_size=args.sample_image_size, + placeholder_tokens=placeholder_tokens, + placeholder_token_ids=placeholder_token_ids, + use_emb_decay=args.use_emb_decay, + emb_decay_target=args.emb_decay_target, + emb_decay=args.emb_decay, max_grad_norm=args.max_grad_norm, ) diff --git a/train_lora.py b/train_lora.py index 1ff25ff..fbec009 100644 --- a/train_lora.py +++ b/train_lora.py @@ -157,12 +157,6 @@ def parse_args(): default=[], help="Tokens to create an alias for.", ) - parser.add_argument( - "--inverted_initializer_tokens", - type=str, - nargs="*", - help="A token to use as initializer word.", - ) parser.add_argument( "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." ) @@ -633,18 +627,6 @@ def parse_args(): "--placeholder_tokens and --initializer_tokens must have the same number of items" ) - if isinstance(args.inverted_initializer_tokens, str): - args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( - args.placeholder_tokens - ) - - if ( - isinstance(args.inverted_initializer_tokens, list) - and len(args.inverted_initializer_tokens) != 0 - ): - args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] - args.initializer_tokens += args.inverted_initializer_tokens - if isinstance(args.num_vectors, int): args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) diff --git a/train_ti.py b/train_ti.py index 1dbd637..8c63493 100644 --- a/train_ti.py +++ b/train_ti.py @@ -111,12 +111,6 @@ def parse_args(): default=[], help="Tokens to create an alias for.", ) - parser.add_argument( - "--inverted_initializer_tokens", - type=str, - nargs="*", - help="A token to use as initializer word.", - ) parser.add_argument( "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." ) @@ -545,18 +539,6 @@ def parse_args(): "--placeholder_tokens and --initializer_tokens must have the same number of items" ) - if isinstance(args.inverted_initializer_tokens, str): - args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( - args.placeholder_tokens - ) - - if ( - isinstance(args.inverted_initializer_tokens, list) - and len(args.inverted_initializer_tokens) != 0 - ): - args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] - args.initializer_tokens += args.inverted_initializer_tokens - if isinstance(args.num_vectors, int): args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) @@ -872,7 +854,7 @@ def main(): optimizer = create_optimizer( text_encoder.text_model.embeddings.token_embedding.parameters(), - lr=learning_rate, + lr=args.learning_rate, ) data_generator = torch.Generator(device="cpu").manual_seed(args.seed) diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index dc19ba3..0f64747 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -30,8 +30,13 @@ def dreambooth_strategy_callbacks( sample_output_dir: Path, checkpoint_output_dir: Path, seed: int, + placeholder_tokens: list[str], + placeholder_token_ids: list[list[int]], train_text_encoder_cycles: int, text_encoder_unfreeze_last_n_layers: int = 2, + use_emb_decay: bool = False, + emb_decay_target: float = 0.4, + emb_decay: float = 1e-2, max_grad_norm: float = 1.0, use_ema: bool = False, ema_inv_gamma: float = 1.0, @@ -112,11 +117,29 @@ def dreambooth_strategy_callbacks( params_to_clip.append(text_encoder.parameters()) accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) + if len(placeholder_tokens) != 0 and use_emb_decay: + params = [ + p + for p in text_encoder.text_model.embeddings.parameters() + if p.grad is not None + ] + return torch.stack(params) if len(params) != 0 else None + @torch.no_grad() - def on_after_optimize(_, lrs: dict[str, float]): + def on_after_optimize(w, lrs: dict[str, float]): if ema_unet is not None: ema_unet.step(unet.parameters()) + if w is not None and "emb" in lrs: + lr = lrs["emb"] + lambda_ = emb_decay * lr + + if lambda_ != 0: + norm = w[:, :].norm(dim=-1, keepdim=True) + w[:].add_( + (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) + ) + def on_log(): if ema_unet is not None: return {"ema_decay": ema_unet.decay} @@ -212,6 +235,7 @@ def dreambooth_prepare( ]: layer.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) # text_encoder.text_model.embeddings.requires_grad_(False) return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler -- cgit v1.2.3-70-g09d2