summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py4
-rw-r--r--train_dreambooth.py9
-rw-r--r--train_ti.py22
-rw-r--r--training/lr.py20
4 files changed, 37 insertions, 18 deletions
diff --git a/data/csv.py b/data/csv.py
index ed8e93d..9ad7dd6 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -122,6 +122,7 @@ class VlpnDataModule():
122 bucket_max_pixels: Optional[int] = None, 122 bucket_max_pixels: Optional[int] = None,
123 progressive_buckets: bool = False, 123 progressive_buckets: bool = False,
124 dropout: float = 0, 124 dropout: float = 0,
125 shuffle: bool = False,
125 interpolation: str = "bicubic", 126 interpolation: str = "bicubic",
126 template_key: str = "template", 127 template_key: str = "template",
127 valid_set_size: Optional[int] = None, 128 valid_set_size: Optional[int] = None,
@@ -150,6 +151,7 @@ class VlpnDataModule():
150 self.bucket_max_pixels = bucket_max_pixels 151 self.bucket_max_pixels = bucket_max_pixels
151 self.progressive_buckets = progressive_buckets 152 self.progressive_buckets = progressive_buckets
152 self.dropout = dropout 153 self.dropout = dropout
154 self.shuffle = shuffle
153 self.template_key = template_key 155 self.template_key = template_key
154 self.interpolation = interpolation 156 self.interpolation = interpolation
155 self.valid_set_size = valid_set_size 157 self.valid_set_size = valid_set_size
@@ -240,7 +242,7 @@ class VlpnDataModule():
240 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 242 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
241 batch_size=self.batch_size, generator=generator, 243 batch_size=self.batch_size, generator=generator,
242 size=self.size, interpolation=self.interpolation, 244 size=self.size, interpolation=self.interpolation,
243 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, 245 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
244 ) 246 )
245 247
246 val_dataset = VlpnDataset( 248 val_dataset = VlpnDataset(
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1a1f516..48a513c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -133,6 +133,12 @@ def parse_args():
133 help="Tag dropout probability.", 133 help="Tag dropout probability.",
134 ) 134 )
135 parser.add_argument( 135 parser.add_argument(
136 "--tag_shuffle",
 137 action="store_true",
138 default=True,
139 help="Shuffle tags.",
140 )
141 parser.add_argument(
136 "--vector_dropout", 142 "--vector_dropout",
137 type=int, 143 type=int,
138 default=0, 144 default=0,
@@ -398,7 +404,7 @@ def parse_args():
398 parser.add_argument( 404 parser.add_argument(
399 "--sample_steps", 405 "--sample_steps",
400 type=int, 406 type=int,
401 default=15, 407 default=20,
402 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 408 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
403 ) 409 )
404 parser.add_argument( 410 parser.add_argument(
@@ -768,6 +774,7 @@ def main():
768 bucket_step_size=args.bucket_step_size, 774 bucket_step_size=args.bucket_step_size,
769 bucket_max_pixels=args.bucket_max_pixels, 775 bucket_max_pixels=args.bucket_max_pixels,
770 dropout=args.tag_dropout, 776 dropout=args.tag_dropout,
777 shuffle=args.tag_shuffle,
771 template_key=args.train_data_template, 778 template_key=args.train_data_template,
772 valid_set_size=args.valid_set_size, 779 valid_set_size=args.valid_set_size,
773 valid_set_repeat=args.valid_set_repeat, 780 valid_set_repeat=args.valid_set_repeat,
diff --git a/train_ti.py b/train_ti.py
index df8d443..35be74c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -169,6 +169,11 @@ def parse_args():
169 help="Tag dropout probability.", 169 help="Tag dropout probability.",
170 ) 170 )
171 parser.add_argument( 171 parser.add_argument(
172 "--tag_shuffle",
 173 action="store_true",
174 help="Shuffle tags.",
175 )
176 parser.add_argument(
172 "--vector_dropout", 177 "--vector_dropout",
173 type=int, 178 type=int,
174 default=0, 179 default=0,
@@ -395,7 +400,7 @@ def parse_args():
395 parser.add_argument( 400 parser.add_argument(
396 "--sample_steps", 401 "--sample_steps",
397 type=int, 402 type=int,
398 default=15, 403 default=20,
399 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 404 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
400 ) 405 )
401 parser.add_argument( 406 parser.add_argument(
@@ -745,6 +750,7 @@ def main():
745 bucket_step_size=args.bucket_step_size, 750 bucket_step_size=args.bucket_step_size,
746 bucket_max_pixels=args.bucket_max_pixels, 751 bucket_max_pixels=args.bucket_max_pixels,
747 dropout=args.tag_dropout, 752 dropout=args.tag_dropout,
753 shuffle=args.tag_shuffle,
748 template_key=args.train_data_template, 754 template_key=args.train_data_template,
749 valid_set_size=args.valid_set_size, 755 valid_set_size=args.valid_set_size,
750 valid_set_repeat=args.valid_set_repeat, 756 valid_set_repeat=args.valid_set_repeat,
@@ -860,6 +866,12 @@ def main():
860 finally: 866 finally:
861 pass 867 pass
862 868
869 def on_clip():
870 accelerator.clip_grad_norm_(
871 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
872 args.max_grad_norm
873 )
874
863 loop = partial( 875 loop = partial(
864 run_model, 876 run_model,
865 vae, 877 vae,
@@ -894,8 +906,9 @@ def main():
894 loop, 906 loop,
895 on_train=on_train, 907 on_train=on_train,
896 on_eval=on_eval, 908 on_eval=on_eval,
909 on_clip=on_clip,
897 ) 910 )
898 lr_finder.run(num_epochs=200, end_lr=1e3) 911 lr_finder.run(num_epochs=100, end_lr=1e3)
899 912
900 plt.savefig(basepath.joinpath("lr.png"), dpi=300) 913 plt.savefig(basepath.joinpath("lr.png"), dpi=300)
901 plt.close() 914 plt.close()
@@ -975,10 +988,7 @@ def main():
975 accelerator.backward(loss) 988 accelerator.backward(loss)
976 989
977 if accelerator.sync_gradients: 990 if accelerator.sync_gradients:
978 accelerator.clip_grad_norm_( 991 on_clip()
979 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
980 args.max_grad_norm
981 )
982 992
983 optimizer.step() 993 optimizer.step()
984 if not accelerator.optimizer_step_was_skipped: 994 if not accelerator.optimizer_step_was_skipped:
diff --git a/training/lr.py b/training/lr.py
index 68e0f72..dfb1743 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -48,7 +48,7 @@ class LRFinder():
48 skip_start: int = 10, 48 skip_start: int = 10,
49 skip_end: int = 5, 49 skip_end: int = 5,
50 num_epochs: int = 100, 50 num_epochs: int = 100,
51 num_train_batches: int = 1, 51 num_train_batches: int = math.inf,
52 num_val_batches: int = math.inf, 52 num_val_batches: int = math.inf,
53 smooth_f: float = 0.05, 53 smooth_f: float = 0.05,
54 ): 54 ):
@@ -156,6 +156,15 @@ class LRFinder():
156 # self.model.load_state_dict(self.model_state) 156 # self.model.load_state_dict(self.model_state)
157 # self.optimizer.load_state_dict(self.optimizer_state) 157 # self.optimizer.load_state_dict(self.optimizer_state)
158 158
159 if skip_end == 0:
160 lrs = lrs[skip_start:]
161 losses = losses[skip_start:]
162 accs = accs[skip_start:]
163 else:
164 lrs = lrs[skip_start:-skip_end]
165 losses = losses[skip_start:-skip_end]
166 accs = accs[skip_start:-skip_end]
167
159 fig, ax_loss = plt.subplots() 168 fig, ax_loss = plt.subplots()
160 ax_acc = ax_loss.twinx() 169 ax_acc = ax_loss.twinx()
161 170
@@ -171,15 +180,6 @@ class LRFinder():
171 print("LR suggestion: steepest gradient") 180 print("LR suggestion: steepest gradient")
172 min_grad_idx = None 181 min_grad_idx = None
173 182
174 if skip_end == 0:
175 lrs = lrs[skip_start:]
176 losses = losses[skip_start:]
177 accs = accs[skip_start:]
178 else:
179 lrs = lrs[skip_start:-skip_end]
180 losses = losses[skip_start:-skip_end]
181 accs = accs[skip_start:-skip_end]
182
183 try: 183 try:
184 min_grad_idx = np.gradient(np.array(losses)).argmin() 184 min_grad_idx = np.gradient(np.array(losses)).argmin()
185 except ValueError: 185 except ValueError: