diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-04 10:32:58 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-04 10:32:58 +0100 |
| commit | bed44095ab99440467c2f302899b970c92baebf8 (patch) | |
| tree | 2b469fe74e0dc22f0fa38413c69135952363f2af /train_ti.py | |
| parent | Fixed reproducibility, more consistant validation (diff) | |
| download | textual-inversion-diff-bed44095ab99440467c2f302899b970c92baebf8.tar.gz textual-inversion-diff-bed44095ab99440467c2f302899b970c92baebf8.tar.bz2 textual-inversion-diff-bed44095ab99440467c2f302899b970c92baebf8.zip | |
Better eval generator
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 12 |
1 file changed, 6 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index 5d6eafc..6f116c3 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -598,7 +598,7 @@ def main(): | |||
| 598 | ) | 598 | ) |
| 599 | 599 | ||
| 600 | if args.find_lr: | 600 | if args.find_lr: |
| 601 | args.learning_rate = 1e-6 | 601 | args.learning_rate = 1e-5 |
| 602 | 602 | ||
| 603 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | 603 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs |
| 604 | if args.use_8bit_adam: | 604 | if args.use_8bit_adam: |
| @@ -799,7 +799,7 @@ def main(): | |||
| 799 | def on_eval(): | 799 | def on_eval(): |
| 800 | tokenizer.eval() | 800 | tokenizer.eval() |
| 801 | 801 | ||
| 802 | def loop(batch, eval: bool = False): | 802 | def loop(step: int, batch, eval: bool = False): |
| 803 | # Convert images to latent space | 803 | # Convert images to latent space |
| 804 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | 804 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() |
| 805 | latents = latents * 0.18215 | 805 | latents = latents * 0.18215 |
| @@ -808,7 +808,7 @@ def main(): | |||
| 808 | noise = torch.randn_like(latents) | 808 | noise = torch.randn_like(latents) |
| 809 | bsz = latents.shape[0] | 809 | bsz = latents.shape[0] |
| 810 | # Sample a random timestep for each image | 810 | # Sample a random timestep for each image |
| 811 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None | 811 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None |
| 812 | timesteps = torch.randint( | 812 | timesteps = torch.randint( |
| 813 | 0, | 813 | 0, |
| 814 | noise_scheduler.config.num_train_timesteps, | 814 | noise_scheduler.config.num_train_timesteps, |
| @@ -881,7 +881,7 @@ def main(): | |||
| 881 | on_train=on_train, | 881 | on_train=on_train, |
| 882 | on_eval=on_eval, | 882 | on_eval=on_eval, |
| 883 | ) | 883 | ) |
| 884 | lr_finder.run(end_lr=1e2) | 884 | lr_finder.run(end_lr=1e3) |
| 885 | 885 | ||
| 886 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 886 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) |
| 887 | plt.close() | 887 | plt.close() |
| @@ -954,7 +954,7 @@ def main(): | |||
| 954 | 954 | ||
| 955 | for step, batch in enumerate(train_dataloader): | 955 | for step, batch in enumerate(train_dataloader): |
| 956 | with accelerator.accumulate(text_encoder): | 956 | with accelerator.accumulate(text_encoder): |
| 957 | loss, acc, bsz = loop(batch) | 957 | loss, acc, bsz = loop(step, batch) |
| 958 | 958 | ||
| 959 | accelerator.backward(loss) | 959 | accelerator.backward(loss) |
| 960 | 960 | ||
| @@ -998,7 +998,7 @@ def main(): | |||
| 998 | 998 | ||
| 999 | with torch.inference_mode(): | 999 | with torch.inference_mode(): |
| 1000 | for step, batch in enumerate(val_dataloader): | 1000 | for step, batch in enumerate(val_dataloader): |
| 1001 | loss, acc, bsz = loop(batch, True) | 1001 | loss, acc, bsz = loop(step, batch, True) |
| 1002 | 1002 | ||
| 1003 | loss = loss.detach_() | 1003 | loss = loss.detach_() |
| 1004 | acc = acc.detach_() | 1004 | acc = acc.detach_() |
