summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-04 09:40:24 +0100
committerVolpeon <git@volpeon.ink>2023-01-04 09:40:24 +0100
commit403f525d0c6900cc6844c0d2f4ecb385fc131969 (patch)
tree385c62ef44cc33abc3c5d4b2084c376551137c5f /train_dreambooth.py
parentDon't use vector_dropout by default (diff)
downloadtextual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.tar.gz
textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.tar.bz2
textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.zip
Fixed reproducibility, more consistent validation
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py24
1 file changed, 19 insertions, 5 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index df8b54c..6d9bae8 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -320,6 +320,12 @@ def parse_args():
320 help="Epsilon value for the Adam optimizer" 320 help="Epsilon value for the Adam optimizer"
321 ) 321 )
322 parser.add_argument( 322 parser.add_argument(
323 "--adam_amsgrad",
324 type=bool,
325 default=False,
326 help="Amsgrad value for the Adam optimizer"
327 )
328 parser.add_argument(
323 "--mixed_precision", 329 "--mixed_precision",
324 type=str, 330 type=str,
325 default="no", 331 default="no",
@@ -642,7 +648,7 @@ def main():
642 ) 648 )
643 649
644 if args.find_lr: 650 if args.find_lr:
645 args.learning_rate = 1e-4 651 args.learning_rate = 1e-6
646 652
647 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 653 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
648 if args.use_8bit_adam: 654 if args.use_8bit_adam:
@@ -674,6 +680,7 @@ def main():
674 betas=(args.adam_beta1, args.adam_beta2), 680 betas=(args.adam_beta1, args.adam_beta2),
675 weight_decay=args.adam_weight_decay, 681 weight_decay=args.adam_weight_decay,
676 eps=args.adam_epsilon, 682 eps=args.adam_epsilon,
683 amsgrad=args.adam_amsgrad,
677 ) 684 )
678 685
679 weight_dtype = torch.float32 686 weight_dtype = torch.float32
@@ -730,6 +737,7 @@ def main():
730 template_key=args.train_data_template, 737 template_key=args.train_data_template,
731 valid_set_size=args.valid_set_size, 738 valid_set_size=args.valid_set_size,
732 num_workers=args.dataloader_num_workers, 739 num_workers=args.dataloader_num_workers,
740 seed=args.seed,
733 filter=keyword_filter, 741 filter=keyword_filter,
734 collate_fn=collate_fn 742 collate_fn=collate_fn
735 ) 743 )
@@ -840,7 +848,7 @@ def main():
840 def on_eval(): 848 def on_eval():
841 tokenizer.eval() 849 tokenizer.eval()
842 850
843 def loop(batch): 851 def loop(batch, eval: bool = False):
844 # Convert images to latent space 852 # Convert images to latent space
845 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 853 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
846 latents = latents * 0.18215 854 latents = latents * 0.18215
@@ -849,8 +857,14 @@ def main():
849 noise = torch.randn_like(latents) 857 noise = torch.randn_like(latents)
850 bsz = latents.shape[0] 858 bsz = latents.shape[0]
851 # Sample a random timestep for each image 859 # Sample a random timestep for each image
852 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, 860 timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
853 (bsz,), device=latents.device) 861 timesteps = torch.randint(
862 0,
863 noise_scheduler.config.num_train_timesteps,
864 (bsz,),
865 generator=timesteps_gen,
866 device=latents.device,
867 )
854 timesteps = timesteps.long() 868 timesteps = timesteps.long()
855 869
856 # Add noise to the latents according to the noise magnitude at each timestep 870 # Add noise to the latents according to the noise magnitude at each timestep
@@ -1051,7 +1065,7 @@ def main():
1051 1065
1052 with torch.inference_mode(): 1066 with torch.inference_mode():
1053 for step, batch in enumerate(val_dataloader): 1067 for step, batch in enumerate(val_dataloader):
1054 loss, acc, bsz = loop(batch) 1068 loss, acc, bsz = loop(batch, True)
1055 1069
1056 loss = loss.detach_() 1070 loss = loss.detach_()
1057 acc = acc.detach_() 1071 acc = acc.detach_()