diff options
author | Volpeon <git@volpeon.ink> | 2023-01-04 10:32:58 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-04 10:32:58 +0100 |
commit | bed44095ab99440467c2f302899b970c92baebf8 (patch) | |
tree | 2b469fe74e0dc22f0fa38413c69135952363f2af /train_dreambooth.py | |
parent | Fixed reproducibility, more consistant validation (diff) | |
download | textual-inversion-diff-bed44095ab99440467c2f302899b970c92baebf8.tar.gz textual-inversion-diff-bed44095ab99440467c2f302899b970c92baebf8.tar.bz2 textual-inversion-diff-bed44095ab99440467c2f302899b970c92baebf8.zip |
Better eval generator
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 6d9bae8..5e6e35d 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -848,7 +848,7 @@ def main(): | |||
848 | def on_eval(): | 848 | def on_eval(): |
849 | tokenizer.eval() | 849 | tokenizer.eval() |
850 | 850 | ||
851 | def loop(batch, eval: bool = False): | 851 | def loop(step: int, batch, eval: bool = False): |
852 | # Convert images to latent space | 852 | # Convert images to latent space |
853 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 853 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
854 | latents = latents * 0.18215 | 854 | latents = latents * 0.18215 |
@@ -857,7 +857,7 @@ def main(): | |||
857 | noise = torch.randn_like(latents) | 857 | noise = torch.randn_like(latents) |
858 | bsz = latents.shape[0] | 858 | bsz = latents.shape[0] |
859 | # Sample a random timestep for each image | 859 | # Sample a random timestep for each image |
860 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None | 860 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None |
861 | timesteps = torch.randint( | 861 | timesteps = torch.randint( |
862 | 0, | 862 | 0, |
863 | noise_scheduler.config.num_train_timesteps, | 863 | noise_scheduler.config.num_train_timesteps, |
@@ -1008,7 +1008,7 @@ def main(): | |||
1008 | 1008 | ||
1009 | for step, batch in enumerate(train_dataloader): | 1009 | for step, batch in enumerate(train_dataloader): |
1010 | with accelerator.accumulate(unet): | 1010 | with accelerator.accumulate(unet): |
1011 | loss, acc, bsz = loop(batch) | 1011 | loss, acc, bsz = loop(step, batch) |
1012 | 1012 | ||
1013 | accelerator.backward(loss) | 1013 | accelerator.backward(loss) |
1014 | 1014 | ||
@@ -1065,7 +1065,7 @@ def main(): | |||
1065 | 1065 | ||
1066 | with torch.inference_mode(): | 1066 | with torch.inference_mode(): |
1067 | for step, batch in enumerate(val_dataloader): | 1067 | for step, batch in enumerate(val_dataloader): |
1068 | loss, acc, bsz = loop(batch, True) | 1068 | loss, acc, bsz = loop(step, batch, True) |
1069 | 1069 | ||
1070 | loss = loss.detach_() | 1070 | loss = loss.detach_() |
1071 | acc = acc.detach_() | 1071 | acc = acc.detach_() |