summary | refs | log | tree | commit | diff | stats
path: root/train_ti.py
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2022-12-27 11:02:49 +0100
committer: Volpeon <git@volpeon.ink> 2022-12-27 11:02:49 +0100
commit: 9d6c75262b6919758e781b8333428861a5bf7ede (patch)
tree: 72e5814413c18d476813867d87c8360c14aee200 /train_ti.py
parent: Set default dimensions to 768; add config inheritance (diff)
download: textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.gz
textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.bz2
textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.zip
Added learning rate finder
Diffstat (limited to 'train_ti.py')
-rw-r--r-- train_ti.py 174
1 files changed, 87 insertions, 87 deletions
diff --git a/train_ti.py b/train_ti.py
index 6e30ac3..ab00b60 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,10 +1,8 @@
1import argparse 1import argparse
2import itertools 2import itertools
3import math 3import math
4import os
5import datetime 4import datetime
6import logging 5import logging
7import json
8from pathlib import Path 6from pathlib import Path
9 7
10import torch 8import torch
@@ -24,6 +22,7 @@ from common import load_text_embeddings, load_text_embedding, load_config
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 22from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule, CSVDataItem 23from data.csv import CSVDataModule, CSVDataItem
26from training.optimization import get_one_cycle_schedule 24from training.optimization import get_one_cycle_schedule
25from training.lr import LRFinder
27from training.ti import patch_trainable_embeddings 26from training.ti import patch_trainable_embeddings
28from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params 27from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
29from models.clip.prompt import PromptProcessor 28from models.clip.prompt import PromptProcessor
@@ -173,6 +172,11 @@ def parse_args():
173 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 172 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
174 ) 173 )
175 parser.add_argument( 174 parser.add_argument(
175 "--find_lr",
176 action="store_true",
177 help="Automatically find a learning rate (no training).",
178 )
179 parser.add_argument(
176 "--learning_rate", 180 "--learning_rate",
177 type=float, 181 type=float,
178 default=1e-4, 182 default=1e-4,
@@ -225,7 +229,7 @@ def parse_args():
225 parser.add_argument( 229 parser.add_argument(
226 "--adam_weight_decay", 230 "--adam_weight_decay",
227 type=float, 231 type=float,
228 default=0, 232 default=1e-2,
229 help="Weight decay to use." 233 help="Weight decay to use."
230 ) 234 )
231 parser.add_argument( 235 parser.add_argument(
@@ -447,16 +451,23 @@ def main():
447 global_step_offset = args.global_step 451 global_step_offset = args.global_step
448 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 452 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
449 basepath = Path(args.output_dir).joinpath(slugify(args.project), now) 453 basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
450 basepath.mkdir(parents=True, exist_ok=True)
451 454
452 accelerator = Accelerator( 455 if args.find_lr:
453 log_with=LoggerType.TENSORBOARD, 456 accelerator = Accelerator(
454 logging_dir=f"{basepath}", 457 gradient_accumulation_steps=args.gradient_accumulation_steps,
455 gradient_accumulation_steps=args.gradient_accumulation_steps, 458 mixed_precision=args.mixed_precision
456 mixed_precision=args.mixed_precision 459 )
457 ) 460 else:
461 basepath.mkdir(parents=True, exist_ok=True)
458 462
459 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) 463 accelerator = Accelerator(
464 log_with=LoggerType.TENSORBOARD,
465 logging_dir=f"{basepath}",
466 gradient_accumulation_steps=args.gradient_accumulation_steps,
467 mixed_precision=args.mixed_precision
468 )
469
470 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
460 471
461 args.seed = args.seed or (torch.random.seed() >> 32) 472 args.seed = args.seed or (torch.random.seed() >> 32)
462 set_seed(args.seed) 473 set_seed(args.seed)
@@ -537,6 +548,9 @@ def main():
537 args.train_batch_size * accelerator.num_processes 548 args.train_batch_size * accelerator.num_processes
538 ) 549 )
539 550
551 if args.find_lr:
552 args.learning_rate = 1e2
553
540 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 554 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
541 if args.use_8bit_adam: 555 if args.use_8bit_adam:
542 try: 556 try:
@@ -671,7 +685,9 @@ def main():
671 685
672 warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps 686 warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
673 687
674 if args.lr_scheduler == "one_cycle": 688 if args.find_lr:
689 lr_scheduler = None
690 elif args.lr_scheduler == "one_cycle":
675 lr_scheduler = get_one_cycle_schedule( 691 lr_scheduler = get_one_cycle_schedule(
676 optimizer=optimizer, 692 optimizer=optimizer,
677 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 693 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
@@ -713,6 +729,63 @@ def main():
713 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 729 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
714 val_steps = num_val_steps_per_epoch * num_epochs 730 val_steps = num_val_steps_per_epoch * num_epochs
715 731
def loop(batch):
    """Run one forward pass of the denoising objective on ``batch``.

    This is the shared step used by the training loop, the validation loop
    and the learning rate finder. It reads ``vae``, ``unet``,
    ``noise_scheduler``, ``prompt_processor``, ``weight_dtype`` and ``args``
    from the enclosing scope.

    Args:
        batch: Mapping providing the ``"pixel_values"``, ``"input_ids"`` and
            ``"attention_mask"`` entries.

    Returns:
        Tuple ``(loss, acc, bsz)``: the scalar loss tensor, the fraction of
        prediction elements that are exactly equal to the target, and the
        number of samples in the batch (measured before any
        prior-preservation split).

    Raises:
        ValueError: If ``noise_scheduler.config.prediction_type`` is neither
            ``"epsilon"`` nor ``"v_prediction"``.
    """
    # Convert images to latent space. The latents are detached so that no
    # gradient flows back into the VAE.
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
    # NOTE(review): 0.18215 is presumably the VAE latent scaling factor of
    # the Stable Diffusion checkpoint in use -- confirm against the model.
    latents = latents * 0.18215

    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    # Sample a random timestep for each image
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Get the text embedding for conditioning
    encoder_hidden_states = prompt_processor.get_embeddings(
        batch["input_ids"], batch["attention_mask"]
    )
    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Get the target for loss depending on the prediction type
    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "epsilon":
        target = noise
    elif prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    if args.num_class_images != 0:
        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

        # Compute prior loss
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

        # Add the prior loss to the instance loss.
        loss = loss + args.prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # Compare against `target`, not `latents`: the model is trained to
    # reproduce the target, and after the prior split above `latents` still
    # has the full batch size while `model_pred` has only half of it, so an
    # element-wise comparison with `latents` would fail to broadcast.
    acc = (model_pred == target).float().mean()

    return loss, acc, bsz
783
784 if args.find_lr:
785 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, loop)
786 lr_finder.run()
787 quit()
788
716 # We need to initialize the trackers we use, and also store our configuration. 789 # We need to initialize the trackers we use, and also store our configuration.
717 # The trackers initializes automatically on the main process. 790 # The trackers initializes automatically on the main process.
718 if accelerator.is_main_process: 791 if accelerator.is_main_process:
@@ -786,54 +859,7 @@ def main():
786 859
787 for step, batch in enumerate(train_dataloader): 860 for step, batch in enumerate(train_dataloader):
788 with accelerator.accumulate(text_encoder): 861 with accelerator.accumulate(text_encoder):
789 # Convert images to latent space 862 loss, acc, bsz = loop(batch)
790 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
791 latents = latents * 0.18215
792
793 # Sample noise that we'll add to the latents
794 noise = torch.randn_like(latents)
795 bsz = latents.shape[0]
796 # Sample a random timestep for each image
797 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
798 (bsz,), device=latents.device)
799 timesteps = timesteps.long()
800
801 # Add noise to the latents according to the noise magnitude at each timestep
802 # (this is the forward diffusion process)
803 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
804
805 # Get the text embedding for conditioning
806 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
807 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
808
809 # Predict the noise residual
810 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
811
812 # Get the target for loss depending on the prediction type
813 if noise_scheduler.config.prediction_type == "epsilon":
814 target = noise
815 elif noise_scheduler.config.prediction_type == "v_prediction":
816 target = noise_scheduler.get_velocity(latents, noise, timesteps)
817 else:
818 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
819
820 if args.num_class_images != 0:
821 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
822 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
823 target, target_prior = torch.chunk(target, 2, dim=0)
824
825 # Compute instance loss
826 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
827
828 # Compute prior loss
829 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
830
831 # Add the prior loss to the instance loss.
832 loss = loss + args.prior_loss_weight * prior_loss
833 else:
834 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
835
836 acc = (model_pred == latents).float().mean()
837 863
838 accelerator.backward(loss) 864 accelerator.backward(loss)
839 865
@@ -873,33 +899,7 @@ def main():
873 899
874 with torch.inference_mode(): 900 with torch.inference_mode():
875 for step, batch in enumerate(val_dataloader): 901 for step, batch in enumerate(val_dataloader):
876 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 902 loss, acc, bsz = loop(batch)
877 latents = latents * 0.18215
878
879 noise = torch.randn_like(latents)
880 bsz = latents.shape[0]
881 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
882 (bsz,), device=latents.device)
883 timesteps = timesteps.long()
884
885 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
886
887 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
888 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
889
890 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
891
892 # Get the target for loss depending on the prediction type
893 if noise_scheduler.config.prediction_type == "epsilon":
894 target = noise
895 elif noise_scheduler.config.prediction_type == "v_prediction":
896 target = noise_scheduler.get_velocity(latents, noise, timesteps)
897 else:
898 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
899
900 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
901
902 acc = (model_pred == latents).float().mean()
903 903
904 avg_loss_val.update(loss.detach_(), bsz) 904 avg_loss_val.update(loss.detach_(), bsz)
905 avg_acc_val.update(acc.detach_(), bsz) 905 avg_acc_val.update(acc.detach_(), bsz)