 infer.py                                            |   8
 pipelines/stable_diffusion/vlpn_stable_diffusion.py |   2
 train_lora.py                                       | 303
 train_ti.py                                         |   6
 training/strategy/lora.py                           |  37
 5 files changed, 305 insertions(+), 51 deletions(-)
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -67,8 +67,8 @@ default_cmds = {
     "batch_num": 1,
     "steps": 30,
     "guidance_scale": 7.0,
-    "sag_scale": 0.75,
-    "lora_scale": 0.5,
+    "sag_scale": 0,
+    "brightness_offset": 0,
     "seed": None,
     "config": None,
 }
@@ -192,7 +192,7 @@ def create_cmd_parser():
         type=float,
     )
     parser.add_argument(
-        "--lora_scale",
+        "--brightness_offset",
        type=float,
    )
    parser.add_argument(
@@ -392,7 +392,7 @@ def generate(output_dir: Path, pipeline, args):
             generator=generator,
             image=init_image,
             strength=args.image_noise,
-            # cross_attention_kwargs={"scale": args.lora_scale},
+            brightness_offset=args.brightness_offset,
         ).images
 
         for j, image in enumerate(images):
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 127ca50..cfc3208 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -403,7 +403,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
-        sag_scale: float = 0.75,
+        sag_scale: float = 0.0,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
diff --git a/train_lora.py b/train_lora.py
index 1ca56d9..39bf455 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -1,6 +1,7 @@
 import argparse
 import datetime
 import logging
+import itertools
 from pathlib import Path
 from functools import partial
 import math
@@ -17,9 +18,10 @@ import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
-from training.functional import train, get_models
+from training.functional import train, add_placeholder_tokens, get_models
 from training.lr import plot_metrics
 from training.strategy.lora import lora_strategy
+from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
 
@@ -81,6 +83,43 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
+        "--placeholder_tokens",
+        type=str,
+        nargs='*',
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--initializer_tokens",
+        type=str,
+        nargs='*',
+        help="A token to use as initializer word."
+    )
+    parser.add_argument(
+        "--initializer_noise",
+        type=float,
+        default=0,
+        help="Noise to apply to the initializer word"
+    )
+    parser.add_argument(
+        "--alias_tokens",
+        type=str,
+        nargs='*',
+        default=[],
+        help="Tokens to create an alias for."
+    )
+    parser.add_argument(
+        "--inverted_initializer_tokens",
+        type=str,
+        nargs='*',
+        help="A token to use as initializer word."
+    )
+    parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
+    parser.add_argument(
         "--exclude_collections",
         type=str,
         nargs='*',
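All of the new token arguments use `nargs='*'`, so each flag accepts zero or more values on the command line. A minimal sketch of how such values parse (a stripped-down parser with made-up example values, not the full one from train_lora.py):

```python
import argparse

# Stand-in for the new train_lora.py token arguments (illustration only).
parser = argparse.ArgumentParser()
parser.add_argument("--placeholder_tokens", type=str, nargs="*")
parser.add_argument("--initializer_tokens", type=str, nargs="*")
parser.add_argument("--num_vectors", type=int, nargs="*")
parser.add_argument("--alias_tokens", type=str, nargs="*", default=[])

args = parser.parse_args(
    "--placeholder_tokens <tok1> <tok2> "
    "--initializer_tokens woman man "
    "--num_vectors 4 4 "
    "--alias_tokens <sks> person".split()
)
print(args.placeholder_tokens)  # ['<tok1>', '<tok2>']
print(args.num_vectors)         # [4, 4]
print(args.alias_tokens)        # ['<sks>', 'person'] -> read later as (alias, target) pairs
```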
@@ -187,6 +226,16 @@ def parse_args():
         default=2000
     )
     parser.add_argument(
+        "--num_pti_epochs",
+        type=int,
+        default=None
+    )
+    parser.add_argument(
+        "--num_pti_steps",
+        type=int,
+        default=500
+    )
+    parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
         default=1,
@@ -258,6 +307,12 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
+        "--learning_rate_pti",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
         "--scale_lr",
         action="store_true",
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
@@ -433,6 +488,23 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e+2,
+        type=float,
+        help="Embedding decay factor."
+    )
+    parser.add_argument(
         "--max_grad_norm",
         default=1.0,
         type=float,
@@ -464,6 +536,40 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
+    if isinstance(args.placeholder_tokens, str):
+        args.placeholder_tokens = [args.placeholder_tokens]
+
+    if isinstance(args.initializer_tokens, str):
+        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
+
+    if len(args.initializer_tokens) == 0:
+        raise ValueError("You must specify --initializer_tokens")
+
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
+
+    if len(args.placeholder_tokens) != len(args.initializer_tokens):
+        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
+
+    if isinstance(args.inverted_initializer_tokens, str):
+        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens)
+
+    if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0:
+        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
+        args.initializer_tokens += args.inverted_initializer_tokens
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
+
+    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
+        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
+
+    if args.alias_tokens is None:
+        args.alias_tokens = []
+
+    if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
+        raise ValueError("--alias_tokens must be a list with an even number of items")
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
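The validation above normalizes everything into parallel lists: single strings become one-element lists, missing placeholder tokens are auto-named `<*i>`, and inverted initializer tokens append an `inv_`-prefixed copy of each placeholder. A plain-Python sketch of that expansion with hypothetical inputs:

```python
# Mirrors the list normalization added to parse_args(); inputs are made up.
placeholder_tokens = []                          # nothing passed on the CLI
initializer_tokens = ["woman", "man"]
inverted_initializer_tokens = ["man", "woman"]

if len(placeholder_tokens) == 0:
    placeholder_tokens = [f"<*{i}>" for i in range(len(initializer_tokens))]

assert len(placeholder_tokens) == len(initializer_tokens)

if len(inverted_initializer_tokens) != 0:
    placeholder_tokens += [f"inv_{t}" for t in placeholder_tokens]
    initializer_tokens += inverted_initializer_tokens

print(placeholder_tokens)  # ['<*0>', '<*1>', 'inv_<*0>', 'inv_<*1>']
print(initializer_tokens)  # ['woman', 'man', 'man', 'woman']
```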
@@ -544,6 +650,19 @@ def main():
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
+    if len(args.alias_tokens) != 0:
+        alias_placeholder_tokens = args.alias_tokens[::2]
+        alias_initializer_tokens = args.alias_tokens[1::2]
+
+        added_tokens, added_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=alias_placeholder_tokens,
+            initializer_tokens=alias_initializer_tokens
+        )
+        embeddings.persist()
+        print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
+
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
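`--alias_tokens` is a flat list read as consecutive pairs: even positions are the new alias tokens, odd positions the tokens they are initialized from. A small illustration of the slicing used above (example tokens are made up):

```python
# Splitting the flat --alias_tokens list into (alias, initializer) pairs.
alias_tokens = ["<sks>", "person", "<style1>", "painting"]

alias_placeholder_tokens = alias_tokens[::2]   # ['<sks>', '<style1>']
alias_initializer_tokens = alias_tokens[1::2]  # ['person', 'painting']

for alias, target in zip(alias_placeholder_tokens, alias_initializer_tokens):
    print(f"{alias!r} initialized from {target!r}")
```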
@@ -552,6 +671,19 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
+        tokenizer=tokenizer,
+        embeddings=embeddings,
+        placeholder_tokens=args.placeholder_tokens,
+        initializer_tokens=args.initializer_tokens,
+        num_vectors=args.num_vectors,
+        initializer_noise=args.initializer_noise,
+    )
+    stats = list(zip(
+        args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
+    ))
+    print(f"Training embeddings: {stats}")
+
     if args.scale_lr:
         args.learning_rate_unet = (
             args.learning_rate_unet * args.gradient_accumulation_steps *
@@ -561,10 +693,15 @@ def main():
             args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
+        args.learning_rate_pti = (
+            args.learning_rate_pti * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
 
     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
+        args.learning_rate_pti = 1e-6
         args.lr_scheduler = "exponential_growth"
 
     if args.optimizer == 'adam8bit':
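With `--scale_lr`, each configured rate is multiplied by the gradient accumulation steps, the per-device batch size, and the number of processes, so the step size tracks the effective batch size. A worked example with hypothetical numbers:

```python
# Effective learning-rate scaling applied when --scale_lr is set (made-up values).
learning_rate_pti = 1e-4
gradient_accumulation_steps = 4
train_batch_size = 2
num_processes = 2

scaled = learning_rate_pti * gradient_accumulation_steps * train_batch_size * num_processes
print(f"{scaled:.4g}")  # 0.0016 -- i.e. 1e-4 scaled by an effective batch size of 16
```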
@@ -663,18 +800,25 @@ def main():
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
+        tokenizer=tokenizer,
         vae=vae,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
+        seed=args.seed,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        offset_noise_strength=args.offset_noise_strength,
+        sample_scheduler=sample_scheduler,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
     )
 
-    checkpoint_output_dir = output_dir / "model"
-    sample_output_dir = output_dir/"samples"
-
-    datamodule = VlpnDataModule(
+    create_datamodule = partial(
+        VlpnDataModule,
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
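`create_datamodule` (and, further down, `create_lr_scheduler`) is a `functools.partial` factory: the shared configuration is bound once and each training phase only supplies what differs. A minimal sketch of the pattern, where `Dummy` merely stands in for `VlpnDataModule`:

```python
from functools import partial

# Hypothetical stand-in: shared kwargs bound once, per-phase kwargs at call time.
class Dummy:
    def __init__(self, data_file, batch_size, filter=None):
        self.data_file, self.batch_size, self.filter = data_file, batch_size, filter

create_dummy = partial(Dummy, data_file="train.jsonl", batch_size=2)

pti_module = create_dummy(filter=lambda item: "<tok>" in item)   # phase-specific filter
lora_module = create_dummy(filter=None)
print(pti_module.batch_size, lora_module.data_file)  # 2 train.jsonl
```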
@@ -693,71 +837,146 @@ def main():
         train_set_pad=args.train_set_pad,
         valid_set_pad=args.valid_set_pad,
         seed=args.seed,
+        dtype=weight_dtype,
+    )
+
+    create_lr_scheduler = partial(
+        get_scheduler,
+        args.lr_scheduler,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        end_lr=1e2,
+        warmup_epochs=args.lr_warmup_epochs,
+        mid_point=args.lr_mid_point,
+    )
+
+    # PTI
+    # --------------------------------------------------------------------------------
+
+    pti_output_dir = output_dir / "pti"
+    pti_checkpoint_output_dir = pti_output_dir / "model"
+    pti_sample_output_dir = pti_output_dir / "samples"
+
+    pti_datamodule = create_datamodule(
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+    )
+    pti_datamodule.setup()
+
+    num_pti_epochs = args.num_pti_epochs
+    pti_sample_frequency = args.sample_frequency
+    if num_pti_epochs is None:
+        num_pti_epochs = math.ceil(
+            args.num_pti_steps / len(pti_datamodule.train_dataset)
+        ) * args.gradient_accumulation_steps
+        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+
+    pti_optimizer = create_optimizer(
+        [
+            {
+                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                "lr": args.learning_rate_pti,
+                "weight_decay": 0,
+            },
+        ]
+    )
+
+    pti_lr_scheduler = create_lr_scheduler(
+        optimizer=pti_optimizer,
+        num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+        train_epochs=num_pti_epochs,
+    )
+
+    metrics = trainer(
+        strategy=textual_inversion_strategy,
+        project="ti",
+        train_dataloader=pti_datamodule.train_dataloader,
+        val_dataloader=pti_datamodule.val_dataloader,
+        optimizer=pti_optimizer,
+        lr_scheduler=pti_lr_scheduler,
+        num_train_epochs=num_pti_epochs,
+        # --
+        sample_output_dir=pti_sample_output_dir,
+        checkpoint_output_dir=pti_checkpoint_output_dir,
+        sample_frequency=pti_sample_frequency,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+    )
+
+    plot_metrics(metrics, output_dir/"lr.png")
+
+    # LORA
+    # --------------------------------------------------------------------------------
+
+    lora_output_dir = output_dir / "lora"
+    lora_checkpoint_output_dir = lora_output_dir / "model"
+    lora_sample_output_dir = lora_output_dir / "samples"
+
+    lora_datamodule = create_datamodule(
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
-        dtype=weight_dtype
     )
-    datamodule.setup()
+    lora_datamodule.setup()
 
     num_train_epochs = args.num_train_epochs
-    sample_frequency = args.sample_frequency
+    lora_sample_frequency = args.sample_frequency
     if num_train_epochs is None:
         num_train_epochs = math.ceil(
-            args.num_train_steps / len(datamodule.train_dataset)
+            args.num_train_steps / len(lora_datamodule.train_dataset)
         ) * args.gradient_accumulation_steps
-        sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
+        lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))
 
-    optimizer = create_optimizer(
+    lora_optimizer = create_optimizer(
         [
             {
                 "params": unet.parameters(),
                 "lr": args.learning_rate_unet,
             },
             {
-                "params": text_encoder.parameters(),
+                "params": itertools.chain(
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                ),
+                "lr": args.learning_rate_text,
+            },
+            {
+                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
                 "lr": args.learning_rate_text,
+                "weight_decay": 0,
             },
         ]
     )
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_training_steps_per_epoch=len(datamodule.train_dataloader),
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        min_lr=args.lr_min_lr,
-        warmup_func=args.lr_warmup_func,
-        annealing_func=args.lr_annealing_func,
-        warmup_exp=args.lr_warmup_exp,
-        annealing_exp=args.lr_annealing_exp,
-        cycles=args.lr_cycles,
-        end_lr=1e2,
+    lora_lr_scheduler = create_lr_scheduler(
+        optimizer=lora_optimizer,
+        num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
         train_epochs=num_train_epochs,
-        warmup_epochs=args.lr_warmup_epochs,
-        mid_point=args.lr_mid_point,
     )
 
     metrics = trainer(
         strategy=lora_strategy,
         project="lora",
-        train_dataloader=datamodule.train_dataloader,
-        val_dataloader=datamodule.val_dataloader,
-        seed=args.seed,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
+        train_dataloader=lora_datamodule.train_dataloader,
+        val_dataloader=lora_datamodule.val_dataloader,
+        optimizer=lora_optimizer,
+        lr_scheduler=lora_lr_scheduler,
         num_train_epochs=num_train_epochs,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=sample_frequency,
-        offset_noise_strength=args.offset_noise_strength,
         # --
-        tokenizer=tokenizer,
-        sample_scheduler=sample_scheduler,
-        sample_output_dir=sample_output_dir,
-        checkpoint_output_dir=checkpoint_output_dir,
+        sample_output_dir=lora_sample_output_dir,
+        checkpoint_output_dir=lora_checkpoint_output_dir,
+        sample_frequency=lora_sample_frequency,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
         max_grad_norm=args.max_grad_norm,
-        sample_batch_size=args.sample_batch_size,
-        sample_num_batches=args.sample_batches,
-        sample_num_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
     )
 
     plot_metrics(metrics, output_dir/"lr.png")
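When `--num_pti_epochs` (or `--num_train_epochs`) is not given, each phase derives an epoch count from the requested step budget and the dataset length, then rescales the step-based sample frequency into epochs. A worked example with made-up sizes:

```python
import math

# Steps -> epochs conversion used for the PTI phase (all numbers hypothetical).
num_pti_steps = 500
gradient_accumulation_steps = 1
train_dataset_len = 120
sample_frequency = 100          # originally "sample every N steps"
num_train_steps = 2000

num_pti_epochs = math.ceil(num_pti_steps / train_dataset_len) * gradient_accumulation_steps
pti_sample_frequency = math.ceil(num_pti_epochs * (sample_frequency / num_train_steps))

print(num_pti_epochs)        # 5 epochs to cover roughly 500 steps
print(pti_sample_frequency)  # 1 -> sample about once per epoch
```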
diff --git a/train_ti.py b/train_ti.py
index fc0d68c..344b412 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -750,6 +750,7 @@ def main():
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
+        tokenizer=tokenizer,
         vae=vae,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
@@ -764,8 +765,6 @@ def main():
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
         # --
-        tokenizer=tokenizer,
-        sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
@@ -774,6 +773,7 @@ def main():
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
         ema_max_decay=args.ema_max_decay,
+        sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -863,9 +863,9 @@ def main():
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
-        sample_frequency=sample_frequency,
         # --
         sample_output_dir=sample_output_dir,
+        sample_frequency=sample_frequency,
         placeholder_tokens=placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
     )
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 209785a..d51a2f3 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -14,6 +14,8 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from peft import get_peft_model_state_dict
 from safetensors.torch import save_file
 
+from slugify import slugify
+
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
@@ -30,6 +32,11 @@ def lora_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     max_grad_norm: float = 1.0,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
@@ -77,6 +84,22 @@ def lora_strategy_callbacks(
             max_grad_norm
         )
 
+        if use_emb_decay:
+            return torch.stack([
+                p
+                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                if p.grad is not None
+            ])
+
+    @torch.no_grad()
+    def on_after_optimize(w, lr: float):
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         if postfix != "end":
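The new `on_after_optimize` callback applies a decay that pulls each trained embedding's norm toward `emb_decay_target`: it adds a vector along the embedding's own direction, scaled by `emb_decay * lr` and by how far the norm is from the target. A standalone sketch with random vectors in place of the real `token_override_embedding` parameters:

```python
import torch

# Norm decay toward a target, as in on_after_optimize (example values only).
emb_decay, emb_decay_target, lr = 1e2, 0.4, 1e-4
w = torch.randn(3, 768)            # three fake embedding vectors

lambda_ = emb_decay * lr           # 0.01
norm = w.norm(dim=-1, keepdim=True)
w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

print(w.norm(dim=-1))  # each norm is pulled slightly toward 0.4
```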
@@ -87,6 +110,12 @@ def lora_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            text_encoder_.text_model.embeddings.save_embed(
+                ids,
+                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+            )
+
         lora_config = {}
         state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
         lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
@@ -126,6 +155,7 @@ def lora_strategy_callbacks(
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
@@ -141,7 +171,12 @@ def lora_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+
+    text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
+
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
 lora_strategy = TrainingStrategy(
