commit bed44095ab99440467c2f302899b970c92baebf8
tree 2b469fe74e0dc22f0fa38413c69135952363f2af
parent Fixed reproducibility, more consistant validation
author Volpeon <git@volpeon.ink> 2023-01-04 10:32:58 +0100
committer Volpeon <git@volpeon.ink> 2023-01-04 10:32:58 +0100

    Better eval generator
 train_dreambooth.py |  8 ++++----
 train_ti.py         | 12 ++++++------
 training/lr.py      |  6 +++---
 3 files changed, 13 insertions(+), 13 deletions(-)
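The substance of the change: the `loop` closure shared by training and validation now receives the dataloader step index, and during evaluation it seeds the timestep generator with `args.seed + step` rather than a constant `args.seed`. Previously every validation batch sampled from an identically seeded generator and therefore drew the same timesteps; now each batch draws a distinct set that is still reproducible from run to run. A minimal standalone sketch of the pattern (`seed` and `num_train_timesteps` here are illustrative stand-ins for `args.seed` and `noise_scheduler.config.num_train_timesteps`):

```python
import torch

seed = 42                   # stand-in for args.seed
num_train_timesteps = 1000  # stand-in for noise_scheduler.config.num_train_timesteps
bsz = 4                     # batch size

def eval_timesteps(step: int) -> torch.Tensor:
    # Seed a dedicated generator per validation batch: deterministic
    # across runs, but different from one batch to the next.
    gen = torch.Generator().manual_seed(seed + step)
    return torch.randint(0, num_train_timesteps, (bsz,), generator=gen)

assert torch.equal(eval_timesteps(0), eval_timesteps(0))  # same step, same draw
print(eval_timesteps(0))  # identical on every run
print(eval_timesteps(1))  # a different, but equally reproducible, draw
```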
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 6d9bae8..5e6e35d 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -848,7 +848,7 @@ def main():
     def on_eval():
         tokenizer.eval()
 
-    def loop(batch, eval: bool = False):
+    def loop(step: int, batch, eval: bool = False):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
         latents = latents * 0.18215
@@ -857,7 +857,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
+        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
         timesteps = torch.randint(
             0,
             noise_scheduler.config.num_train_timesteps,
@@ -1008,7 +1008,7 @@ def main():
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                loss, acc, bsz = loop(batch)
+                loss, acc, bsz = loop(step, batch)
 
                 accelerator.backward(loss)
 
@@ -1065,7 +1065,7 @@ def main():
 
         with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(batch, True)
+                loss, acc, bsz = loop(step, batch, True)
 
                 loss = loss.detach_()
                 acc = acc.detach_()
diff --git a/train_ti.py b/train_ti.py
index 5d6eafc..6f116c3 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -598,7 +598,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e-6
+        args.learning_rate = 1e-5
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -799,7 +799,7 @@ def main():
     def on_eval():
         tokenizer.eval()
 
-    def loop(batch, eval: bool = False):
+    def loop(step: int, batch, eval: bool = False):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
         latents = latents * 0.18215
@@ -808,7 +808,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
+        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
         timesteps = torch.randint(
             0,
             noise_scheduler.config.num_train_timesteps,
@@ -881,7 +881,7 @@ def main():
            on_train=on_train,
            on_eval=on_eval,
        )
-        lr_finder.run(end_lr=1e2)
+        lr_finder.run(end_lr=1e3)
 
        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
        plt.close()
@@ -954,7 +954,7 @@ def main():
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
-                loss, acc, bsz = loop(batch)
+                loss, acc, bsz = loop(step, batch)
 
                 accelerator.backward(loss)
 
@@ -998,7 +998,7 @@ def main():
 
         with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(batch, True)
+                loss, acc, bsz = loop(step, batch, True)
 
                 loss = loss.detach_()
                 acc = acc.detach_()
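train_ti.py also shifts the learning-rate finder's sweep one decade upward: the starting rate goes from 1e-6 to 1e-5 and `end_lr` from 1e2 to 1e3, so the sweep still covers eight orders of magnitude. The interior of `LRFinder.run` is not part of this diff, so assuming the exponential per-batch schedule LR finders conventionally use, the rate at step `i` of an `N`-step sweep would look like:

```python
def sweep_lr(i: int, num_steps: int, start_lr: float = 1e-5, end_lr: float = 1e3) -> float:
    # Exponential interpolation from start_lr to end_lr (a sketch of the
    # conventional LR-finder schedule, not necessarily lr.py's exact code).
    return start_lr * (end_lr / start_lr) ** (i / max(num_steps - 1, 1))

print(sweep_lr(0, 100))   # 1e-05 at the start of the sweep
print(sweep_lr(99, 100))  # 1000.0 at the end
```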
diff --git a/training/lr.py b/training/lr.py
index a3144ba..c8dc040 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -24,7 +24,7 @@ class LRFinder():
         optimizer,
         train_dataloader,
         val_dataloader,
-        loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]],
+        loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], None] = noop,
         on_eval: Callable[[], None] = noop
     ):
@@ -89,7 +89,7 @@ class LRFinder():
                     break
 
                 with self.accelerator.accumulate(self.model):
-                    loss, acc, bsz = self.loss_fn(batch)
+                    loss, acc, bsz = self.loss_fn(step, batch)
 
                     self.accelerator.backward(loss)
 
@@ -108,7 +108,7 @@ class LRFinder():
                 if step >= num_val_batches:
                     break
 
-                loss, acc, bsz = self.loss_fn(batch, True)
+                loss, acc, bsz = self.loss_fn(step, batch, True)
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
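The widened `loss_fn` annotation in `LRFinder.__init__` now matches the updated `loop` closures: the step index comes first, the batch second, and an optional eval flag third. A conforming stand-in callable, for reference (the body is a placeholder, not the training scripts' real loss computation):

```python
from typing import Any, Tuple

import torch

def loop(step: int, batch: Any, eval: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
    # Placeholder loss/accuracy; the real closures in train_ti.py and
    # train_dreambooth.py run the diffusion forward pass here.
    bsz = batch["pixel_values"].shape[0]
    loss = torch.zeros(())
    acc = torch.zeros(())
    return loss, acc, bsz

# LRFinder invokes it positionally, matching both Union branches:
#   training:   loss, acc, bsz = self.loss_fn(step, batch)
#   validation: loss, acc, bsz = self.loss_fn(step, batch, True)
```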
