author     Volpeon <git@volpeon.ink>    2023-01-04 10:32:58 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-04 10:32:58 +0100
commit     bed44095ab99440467c2f302899b970c92baebf8 (patch)
tree       2b469fe74e0dc22f0fa38413c69135952363f2af /train_dreambooth.py
parent     Fixed reproducibility, more consistent validation (diff)
Better eval generator
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  8
1 file changed, 4 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 6d9bae8..5e6e35d 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -848,7 +848,7 @@ def main():
     def on_eval():
         tokenizer.eval()

-    def loop(batch, eval: bool = False):
+    def loop(step: int, batch, eval: bool = False):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
         latents = latents * 0.18215
@@ -857,7 +857,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
+        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
         timesteps = torch.randint(
             0,
             noise_scheduler.config.num_train_timesteps,
@@ -1008,7 +1008,7 @@ def main():

             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(unet):
-                    loss, acc, bsz = loop(batch)
+                    loss, acc, bsz = loop(step, batch)

                     accelerator.backward(loss)

@@ -1065,7 +1065,7 @@ def main():

             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
-                    loss, acc, bsz = loop(batch, True)
+                    loss, acc, bsz = loop(step, batch, True)

                     loss = loss.detach_()
                     acc = acc.detach_()
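
The change seeds the eval-time timestep generator with `args.seed + step` instead of a fixed `args.seed`, so validation stays reproducible across runs while each validation batch draws its own timesteps. Below is a minimal standalone sketch of that idea; `seed`, `bsz`, and `num_train_timesteps` are placeholder values standing in for `args.seed`, the batch size, and `noise_scheduler.config.num_train_timesteps`, not values taken from the repository.

# Sketch only: per-step seeded generator for deterministic but varied eval timesteps.
import torch

seed = 42                   # placeholder for args.seed
num_train_timesteps = 1000  # placeholder for noise_scheduler.config.num_train_timesteps
bsz = 4                     # placeholder validation batch size

def eval_timesteps(step: int) -> torch.Tensor:
    # Re-seeding per step keeps validation reproducible across runs while
    # letting each batch draw a different set of timesteps.
    gen = torch.Generator().manual_seed(seed + step)
    return torch.randint(0, num_train_timesteps, (bsz,), generator=gen)

print(torch.equal(eval_timesteps(0), eval_timesteps(0)))  # True: same step, same draw
print(torch.equal(eval_timesteps(0), eval_timesteps(1)))  # False in practice: steps diverge

With the previous fixed seed, every validation batch sampled the same timesteps, which skews the reported eval loss toward a narrow slice of the noise schedule; offsetting the seed by the step index removes that bias without giving up determinism.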