author    Volpeon <git@volpeon.ink>  2023-01-04 09:40:24 +0100
committer Volpeon <git@volpeon.ink>  2023-01-04 09:40:24 +0100
commit    403f525d0c6900cc6844c0d2f4ecb385fc131969 (patch)
tree      385c62ef44cc33abc3c5d4b2084c376551137c5f /train_ti.py
parent    Don't use vector_dropout by default (diff)
Fixed reproducibility, more consistent validation
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 24 +++++++++++++++++++-----
1 file changed, 19 insertions(+), 5 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 1685dc4..5d6eafc 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -289,6 +289,12 @@ def parse_args():
         help="Epsilon value for the Adam optimizer"
     )
     parser.add_argument(
+        "--adam_amsgrad",
+        type=bool,
+        default=False,
+        help="Amsgrad value for the Adam optimizer"
+    )
+    parser.add_argument(
         "--mixed_precision",
         type=str,
         default="no",
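Note: argparse's type=bool does not parse boolean strings; bool("False") is True, so passing --adam_amsgrad False on the command line would still enable AMSGrad, and only the default behaves as expected. A minimal workaround sketch, not part of this commit (str2bool is a hypothetical helper):

    import argparse

    def str2bool(value: str) -> bool:
        # argparse hands the raw string to type; bool("False") would be True,
        # so map the common spellings explicitly.
        if value.lower() in ("yes", "true", "t", "1"):
            return True
        if value.lower() in ("no", "false", "f", "0"):
            return False
        raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--adam_amsgrad", type=str2bool, default=False,
                        help="Amsgrad value for the Adam optimizer")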
@@ -592,7 +598,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e-4
+        args.learning_rate = 1e-6
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
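The learning-rate finder presumably sweeps upward from this starting value, so lowering the floor from 1e-4 to 1e-6 exposes far more of the useful range to the sweep.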
@@ -612,6 +618,7 @@ def main():
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
     )
 
     weight_dtype = torch.float32
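For context, AMSGrad keeps the running maximum of the second-moment estimate instead of its plain exponential moving average, which stops the effective step size from growing when recent gradients shrink. A minimal sketch of the resulting optimizer construction, with placeholder hyperparameters and parameters standing in for the script's parsed args:

    import torch

    params = [torch.nn.Parameter(torch.zeros(4))]  # placeholder for the trainable embeddings
    optimizer = torch.optim.AdamW(
        params,
        lr=1e-4,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-8,
        amsgrad=True,  # use the max of past squared-gradient averages in the denominator
    )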
@@ -673,6 +680,7 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
+        seed=args.seed,
         filter=keyword_filter,
         collate_fn=collate_fn
     )
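Threading args.seed into the data module presumably pins the train/validation split and any internal shuffling; this covers the data side of the reproducibility fix, while the seeded timestep sampling below covers the loss side.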
@@ -791,7 +799,7 @@ def main():
     def on_eval():
         tokenizer.eval()
 
-    def loop(batch):
+    def loop(batch, eval: bool = False):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
         latents = latents * 0.18215
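Minor nit: the new eval parameter shadows Python's built-in eval() inside loop(). Harmless here since the builtin is never called, but a name like is_eval would avoid the shadowing.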
@@ -800,8 +808,14 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                  (bsz,), device=latents.device)
+        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
+        timesteps = torch.randint(
+            0,
+            noise_scheduler.config.num_train_timesteps,
+            (bsz,),
+            generator=timesteps_gen,
+            device=latents.device,
+        )
         timesteps = timesteps.long()
 
         # Add noise to the latents according to the noise magnitude at each timestep
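This is the heart of the fix: with eval set, timesteps are drawn from a torch.Generator freshly seeded with args.seed, so every validation pass samples identical timesteps (and, since the generator is recreated per call, every validation batch draws the same sequence too), while training keeps generator=None and the global RNG. A self-contained sketch of the idea, assuming 1000 training timesteps:

    from typing import Optional

    import torch

    def sample_timesteps(bsz: int, device: torch.device, seed: Optional[int] = None) -> torch.Tensor:
        # A freshly created, seeded Generator makes the draw deterministic;
        # seed=None falls back to the global RNG (the training path).
        gen = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
        return torch.randint(0, 1000, (bsz,), generator=gen, device=device).long()

    device = torch.device("cpu")
    a = sample_timesteps(8, device, seed=42)
    b = sample_timesteps(8, device, seed=42)
    assert torch.equal(a, b)  # identical draws on every validation pass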
@@ -984,7 +998,7 @@ def main():
 
         with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(batch)
+                loss, acc, bsz = loop(batch, True)
 
                 loss = loss.detach_()
                 acc = acc.detach_()
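Passing True here routes the validation pass through the seeded-generator branch above, so each epoch's validation loss is computed on the same timesteps and the numbers are comparable across epochs; that is the 'more consistent validation' the commit message refers to.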