diff options
author | Volpeon <git@volpeon.ink> | 2023-01-04 09:40:24 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-04 09:40:24 +0100 |
commit | 403f525d0c6900cc6844c0d2f4ecb385fc131969 (patch) | |
tree | 385c62ef44cc33abc3c5d4b2084c376551137c5f /train_ti.py | |
parent | Don't use vector_dropout by default (diff) | |
download | textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.tar.gz textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.tar.bz2 textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.zip |
Fixed reproducibility, more consistent validation
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 24 |
1 file changed, 19 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py index 1685dc4..5d6eafc 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -289,6 +289,12 @@ def parse_args(): | |||
289 | help="Epsilon value for the Adam optimizer" | 289 | help="Epsilon value for the Adam optimizer" |
290 | ) | 290 | ) |
291 | parser.add_argument( | 291 | parser.add_argument( |
292 | "--adam_amsgrad", | ||
293 | type=bool, | ||
294 | default=False, | ||
295 | help="Amsgrad value for the Adam optimizer" | ||
296 | ) | ||
297 | parser.add_argument( | ||
292 | "--mixed_precision", | 298 | "--mixed_precision", |
293 | type=str, | 299 | type=str, |
294 | default="no", | 300 | default="no", |
@@ -592,7 +598,7 @@ def main(): | |||
592 | ) | 598 | ) |
593 | 599 | ||
594 | if args.find_lr: | 600 | if args.find_lr: |
595 | args.learning_rate = 1e-4 | 601 | args.learning_rate = 1e-6 |
596 | 602 | ||
597 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | 603 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs |
598 | if args.use_8bit_adam: | 604 | if args.use_8bit_adam: |
@@ -612,6 +618,7 @@ def main(): | |||
612 | betas=(args.adam_beta1, args.adam_beta2), | 618 | betas=(args.adam_beta1, args.adam_beta2), |
613 | weight_decay=args.adam_weight_decay, | 619 | weight_decay=args.adam_weight_decay, |
614 | eps=args.adam_epsilon, | 620 | eps=args.adam_epsilon, |
621 | amsgrad=args.adam_amsgrad, | ||
615 | ) | 622 | ) |
616 | 623 | ||
617 | weight_dtype = torch.float32 | 624 | weight_dtype = torch.float32 |
@@ -673,6 +680,7 @@ def main(): | |||
673 | template_key=args.train_data_template, | 680 | template_key=args.train_data_template, |
674 | valid_set_size=args.valid_set_size, | 681 | valid_set_size=args.valid_set_size, |
675 | num_workers=args.dataloader_num_workers, | 682 | num_workers=args.dataloader_num_workers, |
683 | seed=args.seed, | ||
676 | filter=keyword_filter, | 684 | filter=keyword_filter, |
677 | collate_fn=collate_fn | 685 | collate_fn=collate_fn |
678 | ) | 686 | ) |
@@ -791,7 +799,7 @@ def main(): | |||
791 | def on_eval(): | 799 | def on_eval(): |
792 | tokenizer.eval() | 800 | tokenizer.eval() |
793 | 801 | ||
794 | def loop(batch): | 802 | def loop(batch, eval: bool = False): |
795 | # Convert images to latent space | 803 | # Convert images to latent space |
796 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | 804 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() |
797 | latents = latents * 0.18215 | 805 | latents = latents * 0.18215 |
@@ -800,8 +808,14 @@ def main(): | |||
800 | noise = torch.randn_like(latents) | 808 | noise = torch.randn_like(latents) |
801 | bsz = latents.shape[0] | 809 | bsz = latents.shape[0] |
802 | # Sample a random timestep for each image | 810 | # Sample a random timestep for each image |
803 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | 811 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None |
804 | (bsz,), device=latents.device) | 812 | timesteps = torch.randint( |
813 | 0, | ||
814 | noise_scheduler.config.num_train_timesteps, | ||
815 | (bsz,), | ||
816 | generator=timesteps_gen, | ||
817 | device=latents.device, | ||
818 | ) | ||
805 | timesteps = timesteps.long() | 819 | timesteps = timesteps.long() |
806 | 820 | ||
807 | # Add noise to the latents according to the noise magnitude at each timestep | 821 | # Add noise to the latents according to the noise magnitude at each timestep |
@@ -984,7 +998,7 @@ def main(): | |||
984 | 998 | ||
985 | with torch.inference_mode(): | 999 | with torch.inference_mode(): |
986 | for step, batch in enumerate(val_dataloader): | 1000 | for step, batch in enumerate(val_dataloader): |
987 | loss, acc, bsz = loop(batch) | 1001 | loss, acc, bsz = loop(batch, True) |
988 | 1002 | ||
989 | loss = loss.detach_() | 1003 | loss = loss.detach_() |
990 | acc = acc.detach_() | 1004 | acc = acc.detach_() |