diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-27 11:02:49 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-27 11:02:49 +0100 |
| commit | 9d6c75262b6919758e781b8333428861a5bf7ede (patch) | |
| tree | 72e5814413c18d476813867d87c8360c14aee200 /train_ti.py | |
| parent | Set default dimensions to 768; add config inheritance (diff) | |
| download | textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.gz textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.bz2 textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.zip | |
Added learning rate finder
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 174 |
1 files changed, 87 insertions, 87 deletions
diff --git a/train_ti.py b/train_ti.py index 6e30ac3..ab00b60 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -1,10 +1,8 @@ | |||
| 1 | import argparse | 1 | import argparse |
| 2 | import itertools | 2 | import itertools |
| 3 | import math | 3 | import math |
| 4 | import os | ||
| 5 | import datetime | 4 | import datetime |
| 6 | import logging | 5 | import logging |
| 7 | import json | ||
| 8 | from pathlib import Path | 6 | from pathlib import Path |
| 9 | 7 | ||
| 10 | import torch | 8 | import torch |
| @@ -24,6 +22,7 @@ from common import load_text_embeddings, load_text_embedding, load_config | |||
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 22 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 25 | from data.csv import CSVDataModule, CSVDataItem | 23 | from data.csv import CSVDataModule, CSVDataItem |
| 26 | from training.optimization import get_one_cycle_schedule | 24 | from training.optimization import get_one_cycle_schedule |
| 25 | from training.lr import LRFinder | ||
| 27 | from training.ti import patch_trainable_embeddings | 26 | from training.ti import patch_trainable_embeddings |
| 28 | from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params | 27 | from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params |
| 29 | from models.clip.prompt import PromptProcessor | 28 | from models.clip.prompt import PromptProcessor |
| @@ -173,6 +172,11 @@ def parse_args(): | |||
| 173 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", | 172 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", |
| 174 | ) | 173 | ) |
| 175 | parser.add_argument( | 174 | parser.add_argument( |
| 175 | "--find_lr", | ||
| 176 | action="store_true", | ||
| 177 | help="Automatically find a learning rate (no training).", | ||
| 178 | ) | ||
| 179 | parser.add_argument( | ||
| 176 | "--learning_rate", | 180 | "--learning_rate", |
| 177 | type=float, | 181 | type=float, |
| 178 | default=1e-4, | 182 | default=1e-4, |
| @@ -225,7 +229,7 @@ def parse_args(): | |||
| 225 | parser.add_argument( | 229 | parser.add_argument( |
| 226 | "--adam_weight_decay", | 230 | "--adam_weight_decay", |
| 227 | type=float, | 231 | type=float, |
| 228 | default=0, | 232 | default=1e-2, |
| 229 | help="Weight decay to use." | 233 | help="Weight decay to use." |
| 230 | ) | 234 | ) |
| 231 | parser.add_argument( | 235 | parser.add_argument( |
| @@ -447,16 +451,23 @@ def main(): | |||
| 447 | global_step_offset = args.global_step | 451 | global_step_offset = args.global_step |
| 448 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 452 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 449 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) | 453 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) |
| 450 | basepath.mkdir(parents=True, exist_ok=True) | ||
| 451 | 454 | ||
| 452 | accelerator = Accelerator( | 455 | if args.find_lr: |
| 453 | log_with=LoggerType.TENSORBOARD, | 456 | accelerator = Accelerator( |
| 454 | logging_dir=f"{basepath}", | 457 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 455 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 458 | mixed_precision=args.mixed_precision |
| 456 | mixed_precision=args.mixed_precision | 459 | ) |
| 457 | ) | 460 | else: |
| 461 | basepath.mkdir(parents=True, exist_ok=True) | ||
| 462 | |||
| 463 | accelerator = Accelerator( | ||
| 464 | log_with=LoggerType.TENSORBOARD, | ||
| 465 | logging_dir=f"{basepath}", | ||
| 466 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 467 | mixed_precision=args.mixed_precision | ||
| 468 | ) | ||
| 458 | 469 | ||
| 459 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | 470 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) |
| 460 | 471 | ||
| 461 | args.seed = args.seed or (torch.random.seed() >> 32) | 472 | args.seed = args.seed or (torch.random.seed() >> 32) |
| 462 | set_seed(args.seed) | 473 | set_seed(args.seed) |
| @@ -537,6 +548,9 @@ def main(): | |||
| 537 | args.train_batch_size * accelerator.num_processes | 548 | args.train_batch_size * accelerator.num_processes |
| 538 | ) | 549 | ) |
| 539 | 550 | ||
| 551 | if args.find_lr: | ||
| 552 | args.learning_rate = 1e2 | ||
| 553 | |||
| 540 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | 554 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs |
| 541 | if args.use_8bit_adam: | 555 | if args.use_8bit_adam: |
| 542 | try: | 556 | try: |
| @@ -671,7 +685,9 @@ def main(): | |||
| 671 | 685 | ||
| 672 | warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps | 686 | warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps |
| 673 | 687 | ||
| 674 | if args.lr_scheduler == "one_cycle": | 688 | if args.find_lr: |
| 689 | lr_scheduler = None | ||
| 690 | elif args.lr_scheduler == "one_cycle": | ||
| 675 | lr_scheduler = get_one_cycle_schedule( | 691 | lr_scheduler = get_one_cycle_schedule( |
| 676 | optimizer=optimizer, | 692 | optimizer=optimizer, |
| 677 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 693 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
| @@ -713,6 +729,63 @@ def main(): | |||
| 713 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 729 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
| 714 | val_steps = num_val_steps_per_epoch * num_epochs | 730 | val_steps = num_val_steps_per_epoch * num_epochs |
| 715 | 731 | ||
| 732 | def loop(batch): | ||
| 733 | # Convert images to latent space | ||
| 734 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | ||
| 735 | latents = latents * 0.18215 | ||
| 736 | |||
| 737 | # Sample noise that we'll add to the latents | ||
| 738 | noise = torch.randn_like(latents) | ||
| 739 | bsz = latents.shape[0] | ||
| 740 | # Sample a random timestep for each image | ||
| 741 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | ||
| 742 | (bsz,), device=latents.device) | ||
| 743 | timesteps = timesteps.long() | ||
| 744 | |||
| 745 | # Add noise to the latents according to the noise magnitude at each timestep | ||
| 746 | # (this is the forward diffusion process) | ||
| 747 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
| 748 | |||
| 749 | # Get the text embedding for conditioning | ||
| 750 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
| 751 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) | ||
| 752 | |||
| 753 | # Predict the noise residual | ||
| 754 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
| 755 | |||
| 756 | # Get the target for loss depending on the prediction type | ||
| 757 | if noise_scheduler.config.prediction_type == "epsilon": | ||
| 758 | target = noise | ||
| 759 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
| 760 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
| 761 | else: | ||
| 762 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
| 763 | |||
| 764 | if args.num_class_images != 0: | ||
| 765 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | ||
| 766 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | ||
| 767 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
| 768 | |||
| 769 | # Compute instance loss | ||
| 770 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | ||
| 771 | |||
| 772 | # Compute prior loss | ||
| 773 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | ||
| 774 | |||
| 775 | # Add the prior loss to the instance loss. | ||
| 776 | loss = loss + args.prior_loss_weight * prior_loss | ||
| 777 | else: | ||
| 778 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
| 779 | |||
| 780 | acc = (model_pred == latents).float().mean() | ||
| 781 | |||
| 782 | return loss, acc, bsz | ||
| 783 | |||
| 784 | if args.find_lr: | ||
| 785 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, loop) | ||
| 786 | lr_finder.run() | ||
| 787 | quit() | ||
| 788 | |||
| 716 | # We need to initialize the trackers we use, and also store our configuration. | 789 | # We need to initialize the trackers we use, and also store our configuration. |
| 717 | # The trackers initializes automatically on the main process. | 790 | # The trackers initializes automatically on the main process. |
| 718 | if accelerator.is_main_process: | 791 | if accelerator.is_main_process: |
| @@ -786,54 +859,7 @@ def main(): | |||
| 786 | 859 | ||
| 787 | for step, batch in enumerate(train_dataloader): | 860 | for step, batch in enumerate(train_dataloader): |
| 788 | with accelerator.accumulate(text_encoder): | 861 | with accelerator.accumulate(text_encoder): |
| 789 | # Convert images to latent space | 862 | loss, acc, bsz = loop(batch) |
| 790 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | ||
| 791 | latents = latents * 0.18215 | ||
| 792 | |||
| 793 | # Sample noise that we'll add to the latents | ||
| 794 | noise = torch.randn_like(latents) | ||
| 795 | bsz = latents.shape[0] | ||
| 796 | # Sample a random timestep for each image | ||
| 797 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | ||
| 798 | (bsz,), device=latents.device) | ||
| 799 | timesteps = timesteps.long() | ||
| 800 | |||
| 801 | # Add noise to the latents according to the noise magnitude at each timestep | ||
| 802 | # (this is the forward diffusion process) | ||
| 803 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
| 804 | |||
| 805 | # Get the text embedding for conditioning | ||
| 806 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
| 807 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) | ||
| 808 | |||
| 809 | # Predict the noise residual | ||
| 810 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
| 811 | |||
| 812 | # Get the target for loss depending on the prediction type | ||
| 813 | if noise_scheduler.config.prediction_type == "epsilon": | ||
| 814 | target = noise | ||
| 815 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
| 816 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
| 817 | else: | ||
| 818 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
| 819 | |||
| 820 | if args.num_class_images != 0: | ||
| 821 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | ||
| 822 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | ||
| 823 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
| 824 | |||
| 825 | # Compute instance loss | ||
| 826 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | ||
| 827 | |||
| 828 | # Compute prior loss | ||
| 829 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | ||
| 830 | |||
| 831 | # Add the prior loss to the instance loss. | ||
| 832 | loss = loss + args.prior_loss_weight * prior_loss | ||
| 833 | else: | ||
| 834 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
| 835 | |||
| 836 | acc = (model_pred == latents).float().mean() | ||
| 837 | 863 | ||
| 838 | accelerator.backward(loss) | 864 | accelerator.backward(loss) |
| 839 | 865 | ||
| @@ -873,33 +899,7 @@ def main(): | |||
| 873 | 899 | ||
| 874 | with torch.inference_mode(): | 900 | with torch.inference_mode(): |
| 875 | for step, batch in enumerate(val_dataloader): | 901 | for step, batch in enumerate(val_dataloader): |
| 876 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 902 | loss, acc, bsz = loop(batch) |
| 877 | latents = latents * 0.18215 | ||
| 878 | |||
| 879 | noise = torch.randn_like(latents) | ||
| 880 | bsz = latents.shape[0] | ||
| 881 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | ||
| 882 | (bsz,), device=latents.device) | ||
| 883 | timesteps = timesteps.long() | ||
| 884 | |||
| 885 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
| 886 | |||
| 887 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
| 888 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) | ||
| 889 | |||
| 890 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
| 891 | |||
| 892 | # Get the target for loss depending on the prediction type | ||
| 893 | if noise_scheduler.config.prediction_type == "epsilon": | ||
| 894 | target = noise | ||
| 895 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
| 896 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
| 897 | else: | ||
| 898 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
| 899 | |||
| 900 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
| 901 | |||
| 902 | acc = (model_pred == latents).float().mean() | ||
| 903 | 903 | ||
| 904 | avg_loss_val.update(loss.detach_(), bsz) | 904 | avg_loss_val.update(loss.detach_(), bsz) |
| 905 | avg_acc_val.update(acc.detach_(), bsz) | 905 | avg_acc_val.update(acc.detach_(), bsz) |
