diff options
author | Volpeon <git@volpeon.ink> | 2023-01-04 09:40:24 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-04 09:40:24 +0100 |
commit | 403f525d0c6900cc6844c0d2f4ecb385fc131969 (patch) | |
tree | 385c62ef44cc33abc3c5d4b2084c376551137c5f /train_dreambooth.py | |
parent | Don't use vector_dropout by default (diff) | |
download | textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.tar.gz textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.tar.bz2 textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.zip |
Fixed reproducibility, more consistent validation
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 24 |
1 file changed, 19 insertions, 5 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index df8b54c..6d9bae8 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -320,6 +320,12 @@ def parse_args(): | |||
320 | help="Epsilon value for the Adam optimizer" | 320 | help="Epsilon value for the Adam optimizer" |
321 | ) | 321 | ) |
322 | parser.add_argument( | 322 | parser.add_argument( |
323 | "--adam_amsgrad", | ||
324 | type=bool, | ||
325 | default=False, | ||
326 | help="Amsgrad value for the Adam optimizer" | ||
327 | ) | ||
328 | parser.add_argument( | ||
323 | "--mixed_precision", | 329 | "--mixed_precision", |
324 | type=str, | 330 | type=str, |
325 | default="no", | 331 | default="no", |
@@ -642,7 +648,7 @@ def main(): | |||
642 | ) | 648 | ) |
643 | 649 | ||
644 | if args.find_lr: | 650 | if args.find_lr: |
645 | args.learning_rate = 1e-4 | 651 | args.learning_rate = 1e-6 |
646 | 652 | ||
647 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | 653 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs |
648 | if args.use_8bit_adam: | 654 | if args.use_8bit_adam: |
@@ -674,6 +680,7 @@ def main(): | |||
674 | betas=(args.adam_beta1, args.adam_beta2), | 680 | betas=(args.adam_beta1, args.adam_beta2), |
675 | weight_decay=args.adam_weight_decay, | 681 | weight_decay=args.adam_weight_decay, |
676 | eps=args.adam_epsilon, | 682 | eps=args.adam_epsilon, |
683 | amsgrad=args.adam_amsgrad, | ||
677 | ) | 684 | ) |
678 | 685 | ||
679 | weight_dtype = torch.float32 | 686 | weight_dtype = torch.float32 |
@@ -730,6 +737,7 @@ def main(): | |||
730 | template_key=args.train_data_template, | 737 | template_key=args.train_data_template, |
731 | valid_set_size=args.valid_set_size, | 738 | valid_set_size=args.valid_set_size, |
732 | num_workers=args.dataloader_num_workers, | 739 | num_workers=args.dataloader_num_workers, |
740 | seed=args.seed, | ||
733 | filter=keyword_filter, | 741 | filter=keyword_filter, |
734 | collate_fn=collate_fn | 742 | collate_fn=collate_fn |
735 | ) | 743 | ) |
@@ -840,7 +848,7 @@ def main(): | |||
840 | def on_eval(): | 848 | def on_eval(): |
841 | tokenizer.eval() | 849 | tokenizer.eval() |
842 | 850 | ||
843 | def loop(batch): | 851 | def loop(batch, eval: bool = False): |
844 | # Convert images to latent space | 852 | # Convert images to latent space |
845 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 853 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
846 | latents = latents * 0.18215 | 854 | latents = latents * 0.18215 |
@@ -849,8 +857,14 @@ def main(): | |||
849 | noise = torch.randn_like(latents) | 857 | noise = torch.randn_like(latents) |
850 | bsz = latents.shape[0] | 858 | bsz = latents.shape[0] |
851 | # Sample a random timestep for each image | 859 | # Sample a random timestep for each image |
852 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | 860 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None |
853 | (bsz,), device=latents.device) | 861 | timesteps = torch.randint( |
862 | 0, | ||
863 | noise_scheduler.config.num_train_timesteps, | ||
864 | (bsz,), | ||
865 | generator=timesteps_gen, | ||
866 | device=latents.device, | ||
867 | ) | ||
854 | timesteps = timesteps.long() | 868 | timesteps = timesteps.long() |
855 | 869 | ||
856 | # Add noise to the latents according to the noise magnitude at each timestep | 870 | # Add noise to the latents according to the noise magnitude at each timestep |
@@ -1051,7 +1065,7 @@ def main(): | |||
1051 | 1065 | ||
1052 | with torch.inference_mode(): | 1066 | with torch.inference_mode(): |
1053 | for step, batch in enumerate(val_dataloader): | 1067 | for step, batch in enumerate(val_dataloader): |
1054 | loss, acc, bsz = loop(batch) | 1068 | loss, acc, bsz = loop(batch, True) |
1055 | 1069 | ||
1056 | loss = loss.detach_() | 1070 | loss = loss.detach_() |
1057 | acc = acc.detach_() | 1071 | acc = acc.detach_() |