diff options
-rw-r--r--  train_dreambooth.py |  8
-rw-r--r--  train_ti.py         | 12
-rw-r--r--  training/lr.py      |  6
3 files changed, 13 insertions, 13 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 6d9bae8..5e6e35d 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -848,7 +848,7 @@ def main():
848 | def on_eval(): | 848 | def on_eval(): |
849 | tokenizer.eval() | 849 | tokenizer.eval() |
850 | 850 | ||
851 | def loop(batch, eval: bool = False): | 851 | def loop(step: int, batch, eval: bool = False): |
852 | # Convert images to latent space | 852 | # Convert images to latent space |
853 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 853 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
854 | latents = latents * 0.18215 | 854 | latents = latents * 0.18215 |
@@ -857,7 +857,7 @@ def main():
857 | noise = torch.randn_like(latents) | 857 | noise = torch.randn_like(latents) |
858 | bsz = latents.shape[0] | 858 | bsz = latents.shape[0] |
859 | # Sample a random timestep for each image | 859 | # Sample a random timestep for each image |
860 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None | 860 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None |
861 | timesteps = torch.randint( | 861 | timesteps = torch.randint( |
862 | 0, | 862 | 0, |
863 | noise_scheduler.config.num_train_timesteps, | 863 | noise_scheduler.config.num_train_timesteps, |
@@ -1008,7 +1008,7 @@ def main():
1008 | 1008 | ||
1009 | for step, batch in enumerate(train_dataloader): | 1009 | for step, batch in enumerate(train_dataloader): |
1010 | with accelerator.accumulate(unet): | 1010 | with accelerator.accumulate(unet): |
1011 | loss, acc, bsz = loop(batch) | 1011 | loss, acc, bsz = loop(step, batch) |
1012 | 1012 | ||
1013 | accelerator.backward(loss) | 1013 | accelerator.backward(loss) |
1014 | 1014 | ||
@@ -1065,7 +1065,7 @@ def main():
1065 | 1065 | ||
1066 | with torch.inference_mode(): | 1066 | with torch.inference_mode(): |
1067 | for step, batch in enumerate(val_dataloader): | 1067 | for step, batch in enumerate(val_dataloader): |
1068 | loss, acc, bsz = loop(batch, True) | 1068 | loss, acc, bsz = loop(step, batch, True) |
1069 | 1069 | ||
1070 | loss = loss.detach_() | 1070 | loss = loss.detach_() |
1071 | acc = acc.detach_() | 1071 | acc = acc.detach_() |
diff --git a/train_ti.py b/train_ti.py
index 5d6eafc..6f116c3 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -598,7 +598,7 @@ def main():
598 | ) | 598 | ) |
599 | 599 | ||
600 | if args.find_lr: | 600 | if args.find_lr: |
601 | args.learning_rate = 1e-6 | 601 | args.learning_rate = 1e-5 |
602 | 602 | ||
603 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | 603 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs |
604 | if args.use_8bit_adam: | 604 | if args.use_8bit_adam: |
@@ -799,7 +799,7 @@ def main():
799 | def on_eval(): | 799 | def on_eval(): |
800 | tokenizer.eval() | 800 | tokenizer.eval() |
801 | 801 | ||
802 | def loop(batch, eval: bool = False): | 802 | def loop(step: int, batch, eval: bool = False): |
803 | # Convert images to latent space | 803 | # Convert images to latent space |
804 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | 804 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() |
805 | latents = latents * 0.18215 | 805 | latents = latents * 0.18215 |
@@ -808,7 +808,7 @@ def main():
808 | noise = torch.randn_like(latents) | 808 | noise = torch.randn_like(latents) |
809 | bsz = latents.shape[0] | 809 | bsz = latents.shape[0] |
810 | # Sample a random timestep for each image | 810 | # Sample a random timestep for each image |
811 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None | 811 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None |
812 | timesteps = torch.randint( | 812 | timesteps = torch.randint( |
813 | 0, | 813 | 0, |
814 | noise_scheduler.config.num_train_timesteps, | 814 | noise_scheduler.config.num_train_timesteps, |
@@ -881,7 +881,7 @@ def main():
881 | on_train=on_train, | 881 | on_train=on_train, |
882 | on_eval=on_eval, | 882 | on_eval=on_eval, |
883 | ) | 883 | ) |
884 | lr_finder.run(end_lr=1e2) | 884 | lr_finder.run(end_lr=1e3) |
885 | 885 | ||
886 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 886 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) |
887 | plt.close() | 887 | plt.close() |
@@ -954,7 +954,7 @@ def main():
954 | 954 | ||
955 | for step, batch in enumerate(train_dataloader): | 955 | for step, batch in enumerate(train_dataloader): |
956 | with accelerator.accumulate(text_encoder): | 956 | with accelerator.accumulate(text_encoder): |
957 | loss, acc, bsz = loop(batch) | 957 | loss, acc, bsz = loop(step, batch) |
958 | 958 | ||
959 | accelerator.backward(loss) | 959 | accelerator.backward(loss) |
960 | 960 | ||
@@ -998,7 +998,7 @@ def main():
998 | 998 | ||
999 | with torch.inference_mode(): | 999 | with torch.inference_mode(): |
1000 | for step, batch in enumerate(val_dataloader): | 1000 | for step, batch in enumerate(val_dataloader): |
1001 | loss, acc, bsz = loop(batch, True) | 1001 | loss, acc, bsz = loop(step, batch, True) |
1002 | 1002 | ||
1003 | loss = loss.detach_() | 1003 | loss = loss.detach_() |
1004 | acc = acc.detach_() | 1004 | acc = acc.detach_() |
diff --git a/training/lr.py b/training/lr.py
index a3144ba..c8dc040 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -24,7 +24,7 @@ class LRFinder():
24 | optimizer, | 24 | optimizer, |
25 | train_dataloader, | 25 | train_dataloader, |
26 | val_dataloader, | 26 | val_dataloader, |
27 | loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]], | 27 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
28 | on_train: Callable[[], None] = noop, | 28 | on_train: Callable[[], None] = noop, |
29 | on_eval: Callable[[], None] = noop | 29 | on_eval: Callable[[], None] = noop |
30 | ): | 30 | ): |
@@ -89,7 +89,7 @@ class LRFinder():
89 | break | 89 | break |
90 | 90 | ||
91 | with self.accelerator.accumulate(self.model): | 91 | with self.accelerator.accumulate(self.model): |
92 | loss, acc, bsz = self.loss_fn(batch) | 92 | loss, acc, bsz = self.loss_fn(step, batch) |
93 | 93 | ||
94 | self.accelerator.backward(loss) | 94 | self.accelerator.backward(loss) |
95 | 95 | ||
@@ -108,7 +108,7 @@ class LRFinder():
108 | if step >= num_val_batches: | 108 | if step >= num_val_batches: |
109 | break | 109 | break |
110 | 110 | ||
111 | loss, acc, bsz = self.loss_fn(batch, True) | 111 | loss, acc, bsz = self.loss_fn(step, batch, True) |
112 | avg_loss.update(loss.detach_(), bsz) | 112 | avg_loss.update(loss.detach_(), bsz) |
113 | avg_acc.update(acc.detach_(), bsz) | 113 | avg_acc.update(acc.detach_(), bsz) |
114 | 114 | ||