diff options
-rw-r--r-- | data/csv.py | 4 | ||||
-rw-r--r-- | train_dreambooth.py | 9 | ||||
-rw-r--r-- | train_ti.py | 22 | ||||
-rw-r--r-- | training/lr.py | 20 |
4 files changed, 37 insertions, 18 deletions
diff --git a/data/csv.py b/data/csv.py index ed8e93d..9ad7dd6 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -122,6 +122,7 @@ class VlpnDataModule(): | |||
122 | bucket_max_pixels: Optional[int] = None, | 122 | bucket_max_pixels: Optional[int] = None, |
123 | progressive_buckets: bool = False, | 123 | progressive_buckets: bool = False, |
124 | dropout: float = 0, | 124 | dropout: float = 0, |
125 | shuffle: bool = False, | ||
125 | interpolation: str = "bicubic", | 126 | interpolation: str = "bicubic", |
126 | template_key: str = "template", | 127 | template_key: str = "template", |
127 | valid_set_size: Optional[int] = None, | 128 | valid_set_size: Optional[int] = None, |
@@ -150,6 +151,7 @@ class VlpnDataModule(): | |||
150 | self.bucket_max_pixels = bucket_max_pixels | 151 | self.bucket_max_pixels = bucket_max_pixels |
151 | self.progressive_buckets = progressive_buckets | 152 | self.progressive_buckets = progressive_buckets |
152 | self.dropout = dropout | 153 | self.dropout = dropout |
154 | self.shuffle = shuffle | ||
153 | self.template_key = template_key | 155 | self.template_key = template_key |
154 | self.interpolation = interpolation | 156 | self.interpolation = interpolation |
155 | self.valid_set_size = valid_set_size | 157 | self.valid_set_size = valid_set_size |
@@ -240,7 +242,7 @@ class VlpnDataModule(): | |||
240 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 242 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
241 | batch_size=self.batch_size, generator=generator, | 243 | batch_size=self.batch_size, generator=generator, |
242 | size=self.size, interpolation=self.interpolation, | 244 | size=self.size, interpolation=self.interpolation, |
243 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, | 245 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, |
244 | ) | 246 | ) |
245 | 247 | ||
246 | val_dataset = VlpnDataset( | 248 | val_dataset = VlpnDataset( |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 1a1f516..48a513c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -133,6 +133,12 @@ def parse_args(): | |||
133 | help="Tag dropout probability.", | 133 | help="Tag dropout probability.", |
134 | ) | 134 | ) |
135 | parser.add_argument( | 135 | parser.add_argument( |
136 | "--tag_shuffle", | ||
137 | type="store_true", | ||
138 | default=True, | ||
139 | help="Shuffle tags.", | ||
140 | ) | ||
141 | parser.add_argument( | ||
136 | "--vector_dropout", | 142 | "--vector_dropout", |
137 | type=int, | 143 | type=int, |
138 | default=0, | 144 | default=0, |
@@ -398,7 +404,7 @@ def parse_args(): | |||
398 | parser.add_argument( | 404 | parser.add_argument( |
399 | "--sample_steps", | 405 | "--sample_steps", |
400 | type=int, | 406 | type=int, |
401 | default=15, | 407 | default=20, |
402 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 408 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
403 | ) | 409 | ) |
404 | parser.add_argument( | 410 | parser.add_argument( |
@@ -768,6 +774,7 @@ def main(): | |||
768 | bucket_step_size=args.bucket_step_size, | 774 | bucket_step_size=args.bucket_step_size, |
769 | bucket_max_pixels=args.bucket_max_pixels, | 775 | bucket_max_pixels=args.bucket_max_pixels, |
770 | dropout=args.tag_dropout, | 776 | dropout=args.tag_dropout, |
777 | shuffle=args.tag_shuffle, | ||
771 | template_key=args.train_data_template, | 778 | template_key=args.train_data_template, |
772 | valid_set_size=args.valid_set_size, | 779 | valid_set_size=args.valid_set_size, |
773 | valid_set_repeat=args.valid_set_repeat, | 780 | valid_set_repeat=args.valid_set_repeat, |
diff --git a/train_ti.py b/train_ti.py index df8d443..35be74c 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -169,6 +169,11 @@ def parse_args(): | |||
169 | help="Tag dropout probability.", | 169 | help="Tag dropout probability.", |
170 | ) | 170 | ) |
171 | parser.add_argument( | 171 | parser.add_argument( |
172 | "--tag_shuffle", | ||
173 | type="store_true", | ||
174 | help="Shuffle tags.", | ||
175 | ) | ||
176 | parser.add_argument( | ||
172 | "--vector_dropout", | 177 | "--vector_dropout", |
173 | type=int, | 178 | type=int, |
174 | default=0, | 179 | default=0, |
@@ -395,7 +400,7 @@ def parse_args(): | |||
395 | parser.add_argument( | 400 | parser.add_argument( |
396 | "--sample_steps", | 401 | "--sample_steps", |
397 | type=int, | 402 | type=int, |
398 | default=15, | 403 | default=20, |
399 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 404 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
400 | ) | 405 | ) |
401 | parser.add_argument( | 406 | parser.add_argument( |
@@ -745,6 +750,7 @@ def main(): | |||
745 | bucket_step_size=args.bucket_step_size, | 750 | bucket_step_size=args.bucket_step_size, |
746 | bucket_max_pixels=args.bucket_max_pixels, | 751 | bucket_max_pixels=args.bucket_max_pixels, |
747 | dropout=args.tag_dropout, | 752 | dropout=args.tag_dropout, |
753 | shuffle=args.tag_shuffle, | ||
748 | template_key=args.train_data_template, | 754 | template_key=args.train_data_template, |
749 | valid_set_size=args.valid_set_size, | 755 | valid_set_size=args.valid_set_size, |
750 | valid_set_repeat=args.valid_set_repeat, | 756 | valid_set_repeat=args.valid_set_repeat, |
@@ -860,6 +866,12 @@ def main(): | |||
860 | finally: | 866 | finally: |
861 | pass | 867 | pass |
862 | 868 | ||
869 | def on_clip(): | ||
870 | accelerator.clip_grad_norm_( | ||
871 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
872 | args.max_grad_norm | ||
873 | ) | ||
874 | |||
863 | loop = partial( | 875 | loop = partial( |
864 | run_model, | 876 | run_model, |
865 | vae, | 877 | vae, |
@@ -894,8 +906,9 @@ def main(): | |||
894 | loop, | 906 | loop, |
895 | on_train=on_train, | 907 | on_train=on_train, |
896 | on_eval=on_eval, | 908 | on_eval=on_eval, |
909 | on_clip=on_clip, | ||
897 | ) | 910 | ) |
898 | lr_finder.run(num_epochs=200, end_lr=1e3) | 911 | lr_finder.run(num_epochs=100, end_lr=1e3) |
899 | 912 | ||
900 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 913 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) |
901 | plt.close() | 914 | plt.close() |
@@ -975,10 +988,7 @@ def main(): | |||
975 | accelerator.backward(loss) | 988 | accelerator.backward(loss) |
976 | 989 | ||
977 | if accelerator.sync_gradients: | 990 | if accelerator.sync_gradients: |
978 | accelerator.clip_grad_norm_( | 991 | on_clip() |
979 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
980 | args.max_grad_norm | ||
981 | ) | ||
982 | 992 | ||
983 | optimizer.step() | 993 | optimizer.step() |
984 | if not accelerator.optimizer_step_was_skipped: | 994 | if not accelerator.optimizer_step_was_skipped: |
diff --git a/training/lr.py b/training/lr.py index 68e0f72..dfb1743 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -48,7 +48,7 @@ class LRFinder(): | |||
48 | skip_start: int = 10, | 48 | skip_start: int = 10, |
49 | skip_end: int = 5, | 49 | skip_end: int = 5, |
50 | num_epochs: int = 100, | 50 | num_epochs: int = 100, |
51 | num_train_batches: int = 1, | 51 | num_train_batches: int = math.inf, |
52 | num_val_batches: int = math.inf, | 52 | num_val_batches: int = math.inf, |
53 | smooth_f: float = 0.05, | 53 | smooth_f: float = 0.05, |
54 | ): | 54 | ): |
@@ -156,6 +156,15 @@ class LRFinder(): | |||
156 | # self.model.load_state_dict(self.model_state) | 156 | # self.model.load_state_dict(self.model_state) |
157 | # self.optimizer.load_state_dict(self.optimizer_state) | 157 | # self.optimizer.load_state_dict(self.optimizer_state) |
158 | 158 | ||
159 | if skip_end == 0: | ||
160 | lrs = lrs[skip_start:] | ||
161 | losses = losses[skip_start:] | ||
162 | accs = accs[skip_start:] | ||
163 | else: | ||
164 | lrs = lrs[skip_start:-skip_end] | ||
165 | losses = losses[skip_start:-skip_end] | ||
166 | accs = accs[skip_start:-skip_end] | ||
167 | |||
159 | fig, ax_loss = plt.subplots() | 168 | fig, ax_loss = plt.subplots() |
160 | ax_acc = ax_loss.twinx() | 169 | ax_acc = ax_loss.twinx() |
161 | 170 | ||
@@ -171,15 +180,6 @@ class LRFinder(): | |||
171 | print("LR suggestion: steepest gradient") | 180 | print("LR suggestion: steepest gradient") |
172 | min_grad_idx = None | 181 | min_grad_idx = None |
173 | 182 | ||
174 | if skip_end == 0: | ||
175 | lrs = lrs[skip_start:] | ||
176 | losses = losses[skip_start:] | ||
177 | accs = accs[skip_start:] | ||
178 | else: | ||
179 | lrs = lrs[skip_start:-skip_end] | ||
180 | losses = losses[skip_start:-skip_end] | ||
181 | accs = accs[skip_start:-skip_end] | ||
182 | |||
183 | try: | 183 | try: |
184 | min_grad_idx = np.gradient(np.array(losses)).argmin() | 184 | min_grad_idx = np.gradient(np.array(losses)).argmin() |
185 | except ValueError: | 185 | except ValueError: |