 environment.yaml    |   1
 train_dreambooth.py | 129
 train_ti.py         | 174
 training/lr.py      | 115
 4 files changed, 257 insertions(+), 162 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 179fa38..c006379 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -5,6 +5,7 @@ channels:
   - defaults
 dependencies:
   - cudatoolkit=11.3
+  - matplotlib=3.6.2
   - numpy=1.23.4
   - pip=22.3.1
   - python=3.9.15
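(The pinned matplotlib is pulled in for the loss-versus-learning-rate plot drawn by the new training/lr.py below.)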
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 08bc9e0..a62cec9 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -843,6 +843,58 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
+    def loop(batch):
+        # Convert images to latent space
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+        latents = latents * 0.18215
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                  (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+        noisy_latents = noisy_latents.to(dtype=unet.dtype)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        if args.num_class_images != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + args.prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        acc = (model_pred == target).float().mean()  # exact-match accuracy against the loss target
+
+        return loss, acc, bsz
+
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -859,54 +911,7 @@ def main():
 
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(unet):
-                    # Convert images to latent space
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-                    noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-                    # Get the text embedding for conditioning
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-                    # Predict the noise residual
-                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    # Get the target for loss depending on the prediction type
-                    if noise_scheduler.config.prediction_type == "epsilon":
-                        target = noise
-                    elif noise_scheduler.config.prediction_type == "v_prediction":
-                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                    else:
-                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-                    if args.num_class_images != 0:
-                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-                        target, target_prior = torch.chunk(target, 2, dim=0)
-
-                        # Compute instance loss
-                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
-
-                        # Compute prior loss
-                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-                        # Add the prior loss to the instance loss.
-                        loss = loss + args.prior_loss_weight * prior_loss
-                    else:
-                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                    acc = (model_pred == latents).float().mean()
+                    loss, acc, bsz = loop(batch)
 
                     accelerator.backward(loss)
 
@@ -960,33 +965,7 @@ def main():
 
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-                    noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    # Get the target for loss depending on the prediction type
-                    if noise_scheduler.config.prediction_type == "epsilon":
-                        target = noise
-                    elif noise_scheduler.config.prediction_type == "v_prediction":
-                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                    else:
-                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                    acc = (model_pred == latents).float().mean()
+                    loss, acc, bsz = loop(batch)
 
                     avg_loss_val.update(loss.detach_(), bsz)
                     avg_acc_val.update(acc.detach_(), bsz)
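The training and validation steps previously duplicated roughly fifty lines each; both now call the shared loop closure. For readers unfamiliar with the prior-preservation branch inside it, here is a minimal, self-contained sketch of the chunked loss on dummy tensors (prior_loss_weight stands in for args.prior_loss_weight):

    import torch
    import torch.nn.functional as F

    # A batch of 4 latents: the first half are instance samples, the second
    # half class ("prior") samples, mirroring how the dataloader concatenates
    # them when num_class_images != 0.
    model_pred = torch.randn(4, 4, 64, 64)
    target = torch.randn(4, 4, 64, 64)
    prior_loss_weight = 1.0  # stand-in for args.prior_loss_weight

    pred_inst, pred_prior = torch.chunk(model_pred, 2, dim=0)
    tgt_inst, tgt_prior = torch.chunk(target, 2, dim=0)

    # Per-image MSE on the instance half, plain MSE on the prior half,
    # combined as in loop() above.
    loss = F.mse_loss(pred_inst, tgt_inst, reduction="none").mean([1, 2, 3]).mean()
    loss = loss + prior_loss_weight * F.mse_loss(pred_prior, tgt_prior, reduction="mean")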
diff --git a/train_ti.py b/train_ti.py
index 6e30ac3..ab00b60 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,10 +1,8 @@
 import argparse
 import itertools
 import math
-import os
 import datetime
 import logging
-import json
 from pathlib import Path
 
 import torch
@@ -24,6 +22,7 @@ from common import load_text_embeddings, load_text_embedding, load_config
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
+from training.lr import LRFinder
 from training.ti import patch_trainable_embeddings
 from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
 from models.clip.prompt import PromptProcessor
@@ -173,6 +172,11 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
+        "--find_lr",
+        action="store_true",
+        help="Automatically find a learning rate (no training).",
+    )
+    parser.add_argument(
         "--learning_rate",
         type=float,
         default=1e-4,
@@ -225,7 +229,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -447,16 +451,23 @@ def main():
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
 
-    accelerator = Accelerator(
-        log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        mixed_precision=args.mixed_precision
-    )
+    if args.find_lr:
+        accelerator = Accelerator(
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            mixed_precision=args.mixed_precision
+        )
+    else:
+        basepath.mkdir(parents=True, exist_ok=True)
+
+        accelerator = Accelerator(
+            log_with=LoggerType.TENSORBOARD,
+            logging_dir=f"{basepath}",
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            mixed_precision=args.mixed_precision
+        )
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
     args.seed = args.seed or (torch.random.seed() >> 32)
     set_seed(args.seed)
@@ -537,6 +548,9 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -671,7 +685,9 @@ def main():
 
     warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
 
-    if args.lr_scheduler == "one_cycle":
+    if args.find_lr:
+        lr_scheduler = None
+    elif args.lr_scheduler == "one_cycle":
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
@@ -713,6 +729,63 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def loop(batch):
+        # Convert images to latent space
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+        latents = latents * 0.18215
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                  (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+        encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
+
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        if args.num_class_images != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + args.prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        acc = (model_pred == target).float().mean()  # exact-match accuracy against the loss target
+
+        return loss, acc, bsz
+
+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, loop)
+        lr_finder.run()
+        quit()
+
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
@@ -786,54 +859,7 @@ def main():
 
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
-                    # Convert images to latent space
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
-                    latents = latents * 0.18215
-
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    # Get the text embedding for conditioning
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-                    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
-
-                    # Predict the noise residual
-                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    # Get the target for loss depending on the prediction type
-                    if noise_scheduler.config.prediction_type == "epsilon":
-                        target = noise
-                    elif noise_scheduler.config.prediction_type == "v_prediction":
-                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                    else:
-                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-                    if args.num_class_images != 0:
-                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-                        target, target_prior = torch.chunk(target, 2, dim=0)
-
-                        # Compute instance loss
-                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
-
-                        # Compute prior loss
-                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-                        # Add the prior loss to the instance loss.
-                        loss = loss + args.prior_loss_weight * prior_loss
-                    else:
-                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                    acc = (model_pred == latents).float().mean()
+                    loss, acc, bsz = loop(batch)
 
                     accelerator.backward(loss)
 
@@ -873,33 +899,7 @@ def main():
 
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-                    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
-
-                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    # Get the target for loss depending on the prediction type
-                    if noise_scheduler.config.prediction_type == "epsilon":
-                        target = noise
-                    elif noise_scheduler.config.prediction_type == "v_prediction":
-                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                    else:
-                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                    acc = (model_pred == latents).float().mean()
+                    loss, acc, bsz = loop(batch)
 
                     avg_loss_val.update(loss.detach_(), bsz)
                     avg_acc_val.update(acc.detach_(), bsz)
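Both scripts branch on the scheduler's prediction_type. For the v_prediction branch, the target is the velocity of Salimans & Ho's parameterization, v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x0, which diffusers exposes as the scheduler's get_velocity method. A minimal standalone sketch of that call (DDPMScheduler here purely for illustration):

    import torch
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="v_prediction")
    latents = torch.randn(1, 4, 64, 64)   # clean latents x0
    noise = torch.randn_like(latents)     # sampled noise eps
    timesteps = torch.randint(0, 1000, (1,))
    # Same call the training loops make when prediction_type == "v_prediction"
    target = scheduler.get_velocity(latents, noise, timesteps)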
diff --git a/training/lr.py b/training/lr.py
new file mode 100644
index 0000000..dd37baa
--- /dev/null
+++ b/training/lr.py
@@ -0,0 +1,115 @@
+import numpy as np
+from torch.optim.lr_scheduler import LambdaLR
+from tqdm.auto import tqdm
+import matplotlib.pyplot as plt
+
+from training.util import AverageMeter
+
+
+class LRFinder:
+    def __init__(self, accelerator, model, optimizer, train_dataloader, loss_fn):
+        self.accelerator = accelerator
+        self.model = model
+        self.optimizer = optimizer
+        self.train_dataloader = train_dataloader
+        self.loss_fn = loss_fn
+
+    def run(self, num_epochs=100, num_steps=1, smooth_f=0.05, diverge_th=5):
+        best_loss = None
+        lrs = []
+        losses = []
+
+        lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs)
+
+        progress_bar = tqdm(
+            range(num_epochs * num_steps),
+            disable=not self.accelerator.is_local_main_process,
+            dynamic_ncols=True
+        )
+        progress_bar.set_description("Epoch X / Y")
+
+        for epoch in range(num_epochs):
+            progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
+
+            avg_loss = AverageMeter()
+
+            for step, batch in enumerate(self.train_dataloader):
+                with self.accelerator.accumulate(self.model):
+                    loss, acc, bsz = self.loss_fn(batch)
+
+                    self.accelerator.backward(loss)
+
+                    self.optimizer.step()
+                    self.optimizer.zero_grad(set_to_none=True)
+
+                    avg_loss.update(loss.detach_(), bsz)
+
+                    if step >= num_steps:
+                        break
+
+                if self.accelerator.sync_gradients:
+                    progress_bar.update(1)
+
+            lr_scheduler.step()
+
+            loss = avg_loss.avg.item()
+            if epoch == 0:
+                best_loss = loss
+            else:
+                if smooth_f > 0:
+                    loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
+                if loss < best_loss:
+                    best_loss = loss
+
+            lr = lr_scheduler.get_last_lr()[0]
+
+            lrs.append(lr)
+            losses.append(loss)
+
+            progress_bar.set_postfix({
+                "loss": loss,
+                "best": best_loss,
+                "lr": lr,
+            })
+
+            if loss > diverge_th * best_loss:
+                print("Stopping early, the loss has diverged")
+                break
+
+        fig, ax = plt.subplots()
+        ax.plot(lrs, losses)
+
+        print("LR suggestion: steepest gradient")
+        min_grad_idx = None
+        try:
+            min_grad_idx = (np.gradient(np.array(losses))).argmin()
+        except ValueError:
+            print(
+                "Failed to compute the gradients, there might not be enough points."
+            )
+        if min_grad_idx is not None:
+            print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
+            ax.scatter(
+                lrs[min_grad_idx],
+                losses[min_grad_idx],
+                s=75,
+                marker="o",
+                color="red",
+                zorder=3,
+                label="steepest gradient",
+            )
+            ax.legend()
+
+        ax.set_xscale("log")
+        ax.set_xlabel("Learning rate")
+        ax.set_ylabel("Loss")
+
+        if fig is not None:
+            plt.show()
+
+
+def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1):
+    def lr_lambda(current_epoch: int):
+        return (current_epoch / num_epochs) ** 5
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
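get_exponential_schedule sweeps the learning rate from near zero up to the optimizer's base LR, which train_ti.py pins to 1e2 under --find_lr; the pace is set by the quintic ramp (epoch / num_epochs) ** 5, so the first epoch runs at LR 0 and only seeds best_loss. A quick sketch of the values the default 100-epoch sweep visits (illustrative arithmetic only):

    # LR visited at epoch e: base_lr * (e / num_epochs) ** 5
    base_lr, num_epochs = 1e2, 100
    lrs = [base_lr * (e / num_epochs) ** 5 for e in range(num_epochs)]
    print(f"{lrs[1]:.1e} {lrs[50]:.3f} {lrs[99]:.1f}")  # 1.0e-08 3.125 95.1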
