From 8abbd633d8ee7500058b2f1f69a6d6611b5a4450 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 17 Mar 2023 15:18:20 +0100 Subject: Test: https://arxiv.org/pdf/2303.09556.pdf --- environment.yaml | 46 ++++++++++++++++++++++------------------------ train_ti.py | 38 ++++++++++++++++++++++++++------------ training/functional.py | 11 ++++++++--- 3 files changed, 56 insertions(+), 39 deletions(-) diff --git a/environment.yaml b/environment.yaml index 018a9ab..9355f37 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,27 +1,25 @@ name: ldd channels: - - pytorch - - nvidia - - xformers/label/dev - - defaults + - pytorch + - nvidia + - xformers/label/dev + - defaults dependencies: - - cudatoolkit=11.7 - - libcufile=1.4.0.31 - - matplotlib=3.6.2 - - numpy=1.23.4 - - pip=22.3.1 - - python=3.10.8 - - pytorch=1.13.1=*cuda* - - torchvision=0.14.1 - - xformers=0.0.17.dev466 - - pip: - - -e . - - -e git+https://github.com/huggingface/diffusers#egg=diffusers - - accelerate==0.16.0 - - bitsandbytes==0.37.0 - - python-slugify>=6.1.2 - - safetensors==0.2.8 - - setuptools==65.6.3 - - test-tube>=0.7.5 - - transformers==4.26.1 - - triton==2.0.0 + - matplotlib=3.6.2 + - numpy=1.23.4 + - pip=22.3.1 + - python=3.10.8 + - pytorch=2.0.0=*cuda11.8* + - torchvision=0.15.0 + # - xformers=0.0.17.dev476 + - pip: + - -e . + - -e git+https://github.com/huggingface/diffusers#egg=diffusers + - accelerate==0.17.1 + - bitsandbytes==0.37.1 + - python-slugify>=6.1.2 + - safetensors==0.3.0 + - setuptools==65.6.3 + - test-tube>=0.7.5 + - transformers==4.27.1 + - triton==2.0.0 diff --git a/train_ti.py b/train_ti.py index 81938c8..fd23517 100644 --- a/train_ti.py +++ b/train_ti.py @@ -61,12 +61,6 @@ def parse_args(): default=None, help="The name of the current project.", ) - parser.add_argument( - "--skip_first", - type=int, - default=0, - help="Tokens to skip training for.", - ) parser.add_argument( "--placeholder_tokens", type=str, @@ -79,6 +73,13 @@ def parse_args(): nargs='*', help="A token to use as initializer word." ) + parser.add_argument( + "--alias_tokens", + type=str, + nargs='*', + default=[], + help="Tokens to create an alias for." + ) parser.add_argument( "--num_vectors", type=int, @@ -420,7 +421,7 @@ def parse_args(): ) parser.add_argument( "--emb_decay", - default=1e-2, + default=1e2, type=float, help="Embedding decay factor." ) @@ -482,6 +483,9 @@ def parse_args(): if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") + if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: + raise ValueError("--alias_tokens must be a list with an even number of items") + if args.sequential: if isinstance(args.train_data_template, str): args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) @@ -543,8 +547,8 @@ def main(): tokenizer.set_dropout(args.vector_dropout) vae.enable_slicing() - vae.set_use_memory_efficient_attention_xformers(True) - unet.enable_xformers_memory_efficient_attention() + # vae.set_use_memory_efficient_attention_xformers(True) + # unet.enable_xformers_memory_efficient_attention() # unet = torch.compile(unet) if args.gradient_checkpointing: @@ -559,6 +563,19 @@ def main(): added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") + if len(args.alias_tokens) != 0: + alias_placeholder_tokens = args.alias_tokens[::2] + alias_initializer_tokens = args.alias_tokens[1::2] + + added_tokens, added_ids = add_placeholder_tokens( + tokenizer=tokenizer, + embeddings=embeddings, + placeholder_tokens=alias_placeholder_tokens, + initializer_tokens=alias_initializer_tokens + ) + embeddings.persist() + print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") + if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * @@ -633,9 +650,6 @@ def main(): ) def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): - if i < args.skip_first: - return - if len(placeholder_tokens) == 1: sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}" metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png" diff --git a/training/functional.py b/training/functional.py index 4565612..2d6553a 100644 --- a/training/functional.py +++ b/training/functional.py @@ -309,8 +309,12 @@ def loss_step( # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise + snr_weights = 1 elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) + snr = target / (1 - target) + snr /= snr + 1 + snr_weights = torch.minimum(snr, torch.tensor([5], device=latents.device)) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") @@ -320,16 +324,17 @@ def loss_step( target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") # Add the prior loss to the instance loss. loss = loss + prior_loss_weight * prior_loss else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = (snr_weights * loss).mean() acc = (model_pred == target).float().mean() return loss, acc, bsz -- cgit v1.2.3-54-g00ecf