summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_dreambooth.py8
-rw-r--r--train_ti.py12
-rw-r--r--training/lr.py6
3 files changed, 13 insertions, 13 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 6d9bae8..5e6e35d 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -848,7 +848,7 @@ def main():
     def on_eval():
         tokenizer.eval()
 
-    def loop(batch, eval: bool = False):
+    def loop(step: int, batch, eval: bool = False):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
         latents = latents * 0.18215
@@ -857,7 +857,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
+        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
         timesteps = torch.randint(
             0,
             noise_scheduler.config.num_train_timesteps,
@@ -1008,7 +1008,7 @@ def main():
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                loss, acc, bsz = loop(batch)
+                loss, acc, bsz = loop(step, batch)
 
                 accelerator.backward(loss)
 
@@ -1065,7 +1065,7 @@ def main():
 
         with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(batch, True)
+                loss, acc, bsz = loop(step, batch, True)
 
                 loss = loss.detach_()
                 acc = acc.detach_()
diff --git a/train_ti.py b/train_ti.py
index 5d6eafc..6f116c3 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -598,7 +598,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e-6
+        args.learning_rate = 1e-5
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -799,7 +799,7 @@ def main():
     def on_eval():
         tokenizer.eval()
 
-    def loop(batch, eval: bool = False):
+    def loop(step: int, batch, eval: bool = False):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
         latents = latents * 0.18215
@@ -808,7 +808,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
+        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
         timesteps = torch.randint(
             0,
             noise_scheduler.config.num_train_timesteps,
@@ -881,7 +881,7 @@ def main():
             on_train=on_train,
             on_eval=on_eval,
         )
-        lr_finder.run(end_lr=1e2)
+        lr_finder.run(end_lr=1e3)
 
         plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
@@ -954,7 +954,7 @@ def main():
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
-                loss, acc, bsz = loop(batch)
+                loss, acc, bsz = loop(step, batch)
 
                 accelerator.backward(loss)
 
@@ -998,7 +998,7 @@ def main():
 
         with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(batch, True)
+                loss, acc, bsz = loop(step, batch, True)
 
                 loss = loss.detach_()
                 acc = acc.detach_()
diff --git a/training/lr.py b/training/lr.py
index a3144ba..c8dc040 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -24,7 +24,7 @@ class LRFinder():
         optimizer,
         train_dataloader,
         val_dataloader,
-        loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]],
+        loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], None] = noop,
         on_eval: Callable[[], None] = noop
     ):
@@ -89,7 +89,7 @@ class LRFinder():
                     break
 
                 with self.accelerator.accumulate(self.model):
-                    loss, acc, bsz = self.loss_fn(batch)
+                    loss, acc, bsz = self.loss_fn(step, batch)
 
                     self.accelerator.backward(loss)
 
@@ -108,7 +108,7 @@ class LRFinder():
                 if step >= num_val_batches:
                     break
 
-                loss, acc, bsz = self.loss_fn(batch, True)
+                loss, acc, bsz = self.loss_fn(step, batch, True)
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 