-rw-r--r--   data/csv.py           4
-rw-r--r--   train_dreambooth.py   9
-rw-r--r--   train_ti.py          22
-rw-r--r--   training/lr.py       20

4 files changed, 37 insertions(+), 18 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index ed8e93d..9ad7dd6 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -122,6 +122,7 @@ class VlpnDataModule():
         bucket_max_pixels: Optional[int] = None,
         progressive_buckets: bool = False,
         dropout: float = 0,
+        shuffle: bool = False,
         interpolation: str = "bicubic",
         template_key: str = "template",
         valid_set_size: Optional[int] = None,
@@ -150,6 +151,7 @@ class VlpnDataModule():
         self.bucket_max_pixels = bucket_max_pixels
         self.progressive_buckets = progressive_buckets
         self.dropout = dropout
+        self.shuffle = shuffle
         self.template_key = template_key
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
@@ -240,7 +242,7 @@ class VlpnDataModule():
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
             batch_size=self.batch_size, generator=generator,
             size=self.size, interpolation=self.interpolation,
-            num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
+            num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
         )
 
         val_dataset = VlpnDataset(
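The new flag is stored on VlpnDataModule and forwarded to the training VlpnDataset in place of the previously hardcoded shuffle=True; the validation dataset is untouched. The dataset implementation is not part of this diff, but as a rough, hypothetical sketch of what shuffling a comma-separated tag caption amounts to (the helper name and caption format are assumptions, not the repository's code):

import random

# Hypothetical helper: randomize the order of comma-separated caption tags.
def shuffle_tags(caption: str, shuffle: bool, rng: random.Random) -> str:
    if not shuffle:
        return caption
    tags = [tag.strip() for tag in caption.split(",")]
    rng.shuffle(tags)  # in-place permutation
    return ", ".join(tags)

rng = random.Random(0)
print(shuffle_tags("portrait, red hair, smiling, outdoors", shuffle=True, rng=rng))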
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1a1f516..48a513c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -133,6 +133,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--tag_shuffle",
+        action="store_true",
+        default=True,
+        help="Shuffle tags.",
+    )
+    parser.add_argument(
         "--vector_dropout",
         type=int,
         default=0,
@@ -398,7 +404,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=15,
+        default=20,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -768,6 +774,7 @@ def main():
         bucket_step_size=args.bucket_step_size,
         bucket_max_pixels=args.bucket_max_pixels,
         dropout=args.tag_dropout,
+        shuffle=args.tag_shuffle,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
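Two behavioral notes on this file: the sample_steps default rises from 15 to 20, and the new --tag_shuffle flag is defined with action="store_true" but also default=True, so as written it can never be turned off from the command line (the train_ti.py version below omits the default and therefore starts disabled). A minimal repro of that argparse behavior, assuming nothing beyond the standard library:

import argparse

# store_true with default=True: the option is True whether or not it is passed.
parser = argparse.ArgumentParser()
parser.add_argument("--tag_shuffle", action="store_true", default=True,
                    help="Shuffle tags.")
print(parser.parse_args([]).tag_shuffle)                 # True
print(parser.parse_args(["--tag_shuffle"]).tag_shuffle)  # True

If an off switch is intended, argparse.BooleanOptionalAction (Python 3.9+) would generate a matching --no-tag_shuffle automatically.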
diff --git a/train_ti.py b/train_ti.py
index df8d443..35be74c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -169,6 +169,11 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--tag_shuffle",
+        action="store_true",
+        help="Shuffle tags.",
+    )
+    parser.add_argument(
         "--vector_dropout",
         type=int,
         default=0,
@@ -395,7 +400,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=15,
+        default=20,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -745,6 +750,7 @@ def main():
         bucket_step_size=args.bucket_step_size,
         bucket_max_pixels=args.bucket_max_pixels,
         dropout=args.tag_dropout,
+        shuffle=args.tag_shuffle,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
@@ -860,6 +866,12 @@ def main():
         finally:
             pass
 
+    def on_clip():
+        accelerator.clip_grad_norm_(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            args.max_grad_norm
+        )
+
     loop = partial(
         run_model,
         vae,
@@ -894,8 +906,9 @@ def main():
             loop,
             on_train=on_train,
             on_eval=on_eval,
+            on_clip=on_clip,
         )
-        lr_finder.run(num_epochs=200, end_lr=1e3)
+        lr_finder.run(num_epochs=100, end_lr=1e3)
 
         plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
@@ -975,10 +988,7 @@ def main():
             accelerator.backward(loss)
 
             if accelerator.sync_gradients:
-                accelerator.clip_grad_norm_(
-                    text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-                    args.max_grad_norm
-                )
+                on_clip()
 
             optimizer.step()
             if not accelerator.optimizer_step_was_skipped:
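The clipping logic in train_ti.py is hoisted into an on_clip callback: the training loop now calls it when gradients sync, and the LR finder receives it as a new keyword argument, so both paths clip identically. A sketch of the pattern with the dependencies made explicit (the factory below is illustrative, not the script's code):

# Illustrative factory; `accelerator`, `model`, and `max_grad_norm` stand in
# for the script's real objects.
def make_on_clip(accelerator, model, max_grad_norm):
    def on_clip():
        # Call .parameters() inside the closure: Module.parameters() returns
        # a fresh generator per call, so the callback stays reusable.
        accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
    return on_clip

In the diff the callback simply closes over text_encoder and args from main(), which achieves the same sharing with less plumbing.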
diff --git a/training/lr.py b/training/lr.py
index 68e0f72..dfb1743 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -48,7 +48,7 @@ class LRFinder():
         skip_start: int = 10,
         skip_end: int = 5,
         num_epochs: int = 100,
-        num_train_batches: int = 1,
+        num_train_batches: int = math.inf,
         num_val_batches: int = math.inf,
         smooth_f: float = 0.05,
     ):
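Raising the num_train_batches default from 1 to math.inf means the finder now consumes the full training loader each epoch instead of a single batch, matching the existing num_val_batches default. (Strictly, math.inf is a float, so the int annotation is now inaccurate, though harmless at runtime.) Since math.inf compares greater than any int, a batch-index check against it never triggers; a small sketch of the idiom (take is a hypothetical helper):

import math

# Yield at most `limit` items; with the math.inf default the cap never fires.
def take(loader, limit=math.inf):
    for i, batch in enumerate(loader):
        if i >= limit:
            break
        yield batch

print(list(take(range(5))))     # [0, 1, 2, 3, 4]
print(list(take(range(5), 2)))  # [0, 1]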
@@ -156,6 +156,15 @@ class LRFinder():
         # self.model.load_state_dict(self.model_state)
         # self.optimizer.load_state_dict(self.optimizer_state)
 
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
         fig, ax_loss = plt.subplots()
         ax_acc = ax_loss.twinx()
 
@@ -171,15 +180,6 @@ class LRFinder():
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
 
-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
         try:
             min_grad_idx = np.gradient(np.array(losses)).argmin()
         except ValueError:
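Trimming by skip_start/skip_end moves from the suggestion step to just before plotting, so the plotted curves now exclude the noisy warmup and tail points as well, while the steepest-gradient suggestion still operates on the same trimmed arrays. The skip_end == 0 branch exists because a Python slice ending at -0 is empty rather than "drop nothing"; compactly (trim is a hypothetical helper):

# values[start:-0] == values[start:0] == [], hence the explicit zero check.
def trim(values, skip_start=10, skip_end=5):
    return values[skip_start:] if skip_end == 0 else values[skip_start:-skip_end]

data = list(range(20))
print(trim(data))              # [10, 11, 12, 13, 14]
print(trim(data, skip_end=0))  # [10, 11, ..., 19]: only the head is dropped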
