 environment.yaml       | 46
 train_ti.py            | 38
 training/functional.py | 11
 3 files changed, 56 insertions(+), 39 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 018a9ab..9355f37 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -1,27 +1,25 @@
 name: ldd
 channels:
   - pytorch
   - nvidia
   - xformers/label/dev
   - defaults
 dependencies:
-  - cudatoolkit=11.7
-  - libcufile=1.4.0.31
   - matplotlib=3.6.2
   - numpy=1.23.4
   - pip=22.3.1
   - python=3.10.8
-  - pytorch=1.13.1=*cuda*
-  - torchvision=0.14.1
-  - xformers=0.0.17.dev466
+  - pytorch=2.0.0=*cuda11.8*
+  - torchvision=0.15.0
+  # - xformers=0.0.17.dev476
   - pip:
     - -e .
     - -e git+https://github.com/huggingface/diffusers#egg=diffusers
-    - accelerate==0.16.0
-    - bitsandbytes==0.37.0
+    - accelerate==0.17.1
+    - bitsandbytes==0.37.1
     - python-slugify>=6.1.2
-    - safetensors==0.2.8
+    - safetensors==0.3.0
     - setuptools==65.6.3
     - test-tube>=0.7.5
-    - transformers==4.26.1
+    - transformers==4.27.1
     - triton==2.0.0
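The notable change here is the jump from PyTorch 1.13.1 (CUDA 11.7) to a PyTorch 2.0.0 build against CUDA 11.8, with the xformers pin commented out. A minimal sanity check after recreating the environment (a sketch; the expected versions are simply the pins listed above):

    import torch, torchvision, transformers

    # Expect the pins from the updated environment.yaml.
    assert torch.__version__.startswith("2.0"), torch.__version__
    assert torch.cuda.is_available()          # pytorch=2.0.0=*cuda11.8* build
    print(torch.version.cuda)                 # should report 11.8
    print(torchvision.__version__, transformers.__version__)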
diff --git a/train_ti.py b/train_ti.py
index 81938c8..fd23517 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -62,12 +62,6 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
-        "--skip_first",
-        type=int,
-        default=0,
-        help="Tokens to skip training for.",
-    )
-    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -80,6 +74,13 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--alias_tokens",
+        type=str,
+        nargs='*',
+        default=[],
+        help="Tokens to create an alias for."
+    )
+    parser.add_argument(
         "--num_vectors",
         type=int,
         nargs='*',
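Since nargs='*' collects everything into one flat list, --alias_tokens is expected as alternating pairs: the new alias token followed by the token it is initialized from. The list is split into pairs further down in main(). A hypothetical invocation (token names made up for illustration):

    #   python train_ti.py ... --alias_tokens "<night>" "<dark>" "<forest>" "<woods>"
    # argparse yields a flat list, which main() slices pairwise:
    alias_tokens = ["<night>", "<dark>", "<forest>", "<woods>"]
    alias_placeholder_tokens = alias_tokens[::2]   # ["<night>", "<forest>"]
    alias_initializer_tokens = alias_tokens[1::2]  # ["<dark>", "<woods>"]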
@@ -420,7 +421,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay",
-        default=1e-2,
+        default=1e2,
         type=float,
         help="Embedding decay factor."
     )
@@ -482,6 +483,9 @@ def parse_args():
     if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
+    if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
+        raise ValueError("--alias_tokens must be a list with an even number of items")
+
     if args.sequential:
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
@@ -543,8 +547,8 @@ def main():
     tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
+    # vae.set_use_memory_efficient_attention_xformers(True)
+    # unet.enable_xformers_memory_efficient_attention()
     # unet = torch.compile(unet)
 
     if args.gradient_checkpointing:
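Commenting out the xformers calls lines up with the PyTorch 2.0 bump in environment.yaml: torch 2.0 ships a fused attention primitive, torch.nn.functional.scaled_dot_product_attention, and recent diffusers versions select an SDPA-based attention processor by default when running on torch >= 2.0, so the explicit xformers enablement is presumably no longer needed. A small sketch of the underlying primitive (dummy shapes, not the training code):

    import torch
    import torch.nn.functional as F

    # (batch, heads, tokens, head_dim) -- illustrative shapes only
    q = torch.randn(2, 8, 1024, 64)
    k = torch.randn(2, 8, 1024, 64)
    v = torch.randn(2, 8, 1024, 64)
    out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([2, 8, 1024, 64])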
@@ -559,6 +563,19 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
+    if len(args.alias_tokens) != 0:
+        alias_placeholder_tokens = args.alias_tokens[::2]
+        alias_initializer_tokens = args.alias_tokens[1::2]
+
+        added_tokens, added_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=alias_placeholder_tokens,
+            initializer_tokens=alias_initializer_tokens
+        )
+        embeddings.persist()
+        print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -633,9 +650,6 @@ def main():
     )
 
     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
-        if i < args.skip_first:
-            return
-
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}"
             metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png"
diff --git a/training/functional.py b/training/functional.py
index 4565612..2d6553a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -309,8 +309,12 @@ def loss_step(
     # Get the target for loss depending on the prediction type
     if noise_scheduler.config.prediction_type == "epsilon":
         target = noise
+        snr_weights = 1
     elif noise_scheduler.config.prediction_type == "v_prediction":
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        snr = target / (1 - target)
+        snr /= snr + 1
+        snr_weights = torch.minimum(snr, torch.tensor([5], device=latents.device))
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
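The clamp at 5 resembles Min-SNR-gamma loss weighting (Hang et al., 2023) with gamma = 5, although here the weight is derived elementwise from the velocity target itself. The formulation usually quoted computes SNR from the noise schedule instead; a hedged sketch of that variant, assuming a diffusers-style scheduler that exposes alphas_cumprod (this is not the code used in the hunk above):

    import torch

    def min_snr_weights(noise_scheduler, timesteps, gamma=5.0, v_prediction=True):
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t), taken from the noise schedule.
        alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps]
        snr = alphas_cumprod / (1 - alphas_cumprod)
        weights = torch.minimum(snr, torch.full_like(snr, gamma))
        # v-prediction: min(SNR, gamma) / (SNR + 1); epsilon: min(SNR, gamma) / SNR.
        weights = weights / (snr + 1) if v_prediction else weights / snr
        # Returned shape is (bsz,); reshape to (bsz, 1, 1, 1) before multiplying
        # with a per-element loss.
        return weights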
@@ -320,16 +324,17 @@ def loss_step(
         target, target_prior = torch.chunk(target, 2, dim=0)
 
         # Compute instance loss
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
         # Compute prior loss
-        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
 
         # Add the prior loss to the instance loss.
         loss = loss + prior_loss_weight * prior_loss
     else:
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
+    loss = (snr_weights * loss).mean()
     acc = (model_pred == target).float().mean()
 
     return loss, acc, bsz
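With reduction="none" the MSE keeps its per-element shape, and the final weighted .mean() does the reduction in one step. For epsilon prediction, where snr_weights is 1, this is numerically identical to the previous reduction="mean" behaviour. A small shape/equivalence check with dummy tensors (not the real latents):

    import torch
    import torch.nn.functional as F

    model_pred = torch.randn(2, 4, 8, 8)
    target = torch.randn(2, 4, 8, 8)

    per_element = F.mse_loss(model_pred, target, reduction="none")  # (2, 4, 8, 8)
    snr_weights = 1                                                 # epsilon branch
    assert torch.isclose((snr_weights * per_element).mean(),
                         F.mse_loss(model_pred, target, reduction="mean"))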
