diff options
author | Volpeon <git@volpeon.ink> | 2023-01-04 10:32:58 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-04 10:32:58 +0100 |
commit | bed44095ab99440467c2f302899b970c92baebf8 (patch) | |
tree | 2b469fe74e0dc22f0fa38413c69135952363f2af /train_ti.py | |
parent | Fixed reproducibility, more consistent validation (diff) | |
download | textual-inversion-diff-bed44095ab99440467c2f302899b970c92baebf8.tar.gz textual-inversion-diff-bed44095ab99440467c2f302899b970c92baebf8.tar.bz2 textual-inversion-diff-bed44095ab99440467c2f302899b970c92baebf8.zip |
Better eval generator
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 12 |
1 file changed, 6 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index 5d6eafc..6f116c3 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -598,7 +598,7 @@ def main(): | |||
598 | ) | 598 | ) |
599 | 599 | ||
600 | if args.find_lr: | 600 | if args.find_lr: |
601 | args.learning_rate = 1e-6 | 601 | args.learning_rate = 1e-5 |
602 | 602 | ||
603 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | 603 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs |
604 | if args.use_8bit_adam: | 604 | if args.use_8bit_adam: |
@@ -799,7 +799,7 @@ def main(): | |||
799 | def on_eval(): | 799 | def on_eval(): |
800 | tokenizer.eval() | 800 | tokenizer.eval() |
801 | 801 | ||
802 | def loop(batch, eval: bool = False): | 802 | def loop(step: int, batch, eval: bool = False): |
803 | # Convert images to latent space | 803 | # Convert images to latent space |
804 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | 804 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() |
805 | latents = latents * 0.18215 | 805 | latents = latents * 0.18215 |
@@ -808,7 +808,7 @@ def main(): | |||
808 | noise = torch.randn_like(latents) | 808 | noise = torch.randn_like(latents) |
809 | bsz = latents.shape[0] | 809 | bsz = latents.shape[0] |
810 | # Sample a random timestep for each image | 810 | # Sample a random timestep for each image |
811 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None | 811 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None |
812 | timesteps = torch.randint( | 812 | timesteps = torch.randint( |
813 | 0, | 813 | 0, |
814 | noise_scheduler.config.num_train_timesteps, | 814 | noise_scheduler.config.num_train_timesteps, |
@@ -881,7 +881,7 @@ def main(): | |||
881 | on_train=on_train, | 881 | on_train=on_train, |
882 | on_eval=on_eval, | 882 | on_eval=on_eval, |
883 | ) | 883 | ) |
884 | lr_finder.run(end_lr=1e2) | 884 | lr_finder.run(end_lr=1e3) |
885 | 885 | ||
886 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 886 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) |
887 | plt.close() | 887 | plt.close() |
@@ -954,7 +954,7 @@ def main(): | |||
954 | 954 | ||
955 | for step, batch in enumerate(train_dataloader): | 955 | for step, batch in enumerate(train_dataloader): |
956 | with accelerator.accumulate(text_encoder): | 956 | with accelerator.accumulate(text_encoder): |
957 | loss, acc, bsz = loop(batch) | 957 | loss, acc, bsz = loop(step, batch) |
958 | 958 | ||
959 | accelerator.backward(loss) | 959 | accelerator.backward(loss) |
960 | 960 | ||
@@ -998,7 +998,7 @@ def main(): | |||
998 | 998 | ||
999 | with torch.inference_mode(): | 999 | with torch.inference_mode(): |
1000 | for step, batch in enumerate(val_dataloader): | 1000 | for step, batch in enumerate(val_dataloader): |
1001 | loss, acc, bsz = loop(batch, True) | 1001 | loss, acc, bsz = loop(step, batch, True) |
1002 | 1002 | ||
1003 | loss = loss.detach_() | 1003 | loss = loss.detach_() |
1004 | acc = acc.detach_() | 1004 | acc = acc.detach_() |