From 33e7d2ed37e32657ca94d92815043026c4cea7c0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 10 Jan 2023 09:22:02 +0100
Subject: Added arg to disable tag shuffling

---
 data/csv.py         |  4 +++-
 train_dreambooth.py |  9 ++++++++-
 train_ti.py         | 22 ++++++++++++++++------
 training/lr.py      | 20 ++++++++++----------
 4 files changed, 37 insertions(+), 18 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index ed8e93d..9ad7dd6 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -122,6 +122,7 @@ class VlpnDataModule():
         bucket_max_pixels: Optional[int] = None,
         progressive_buckets: bool = False,
         dropout: float = 0,
+        shuffle: bool = False,
         interpolation: str = "bicubic",
         template_key: str = "template",
         valid_set_size: Optional[int] = None,
@@ -150,6 +151,7 @@ class VlpnDataModule():
         self.bucket_max_pixels = bucket_max_pixels
         self.progressive_buckets = progressive_buckets
         self.dropout = dropout
+        self.shuffle = shuffle
         self.template_key = template_key
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
@@ -240,7 +242,7 @@ class VlpnDataModule():
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
             batch_size=self.batch_size, generator=generator,
             size=self.size, interpolation=self.interpolation,
-            num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
+            num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
         )

         val_dataset = VlpnDataset(
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1a1f516..48a513c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -132,6 +132,12 @@ def parse_args():
         default=0.1,
         help="Tag dropout probability.",
     )
+    parser.add_argument(
+        "--tag_shuffle",
+        action="store_true",
+        default=True,
+        help="Shuffle tags.",
+    )
     parser.add_argument(
         "--vector_dropout",
         type=int,
@@ -398,7 +404,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=15,
+        default=20,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -768,6 +774,7 @@ def main():
         bucket_step_size=args.bucket_step_size,
         bucket_max_pixels=args.bucket_max_pixels,
         dropout=args.tag_dropout,
+        shuffle=args.tag_shuffle,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
diff --git a/train_ti.py b/train_ti.py
index df8d443..35be74c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -168,6 +168,11 @@ def parse_args():
         default=0.1,
         help="Tag dropout probability.",
     )
+    parser.add_argument(
+        "--tag_shuffle",
+        action="store_true",
+        help="Shuffle tags.",
+    )
     parser.add_argument(
         "--vector_dropout",
         type=int,
@@ -395,7 +400,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=15,
+        default=20,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -745,6 +750,7 @@ def main():
         bucket_step_size=args.bucket_step_size,
         bucket_max_pixels=args.bucket_max_pixels,
         dropout=args.tag_dropout,
+        shuffle=args.tag_shuffle,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
@@ -860,6 +866,12 @@ def main():
         finally:
             pass

+    def on_clip():
+        accelerator.clip_grad_norm_(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            args.max_grad_norm
+        )
+
     loop = partial(
         run_model,
         vae,
@@ -894,8 +906,9 @@ def main():
             loop,
             on_train=on_train,
             on_eval=on_eval,
+            on_clip=on_clip,
         )
-        lr_finder.run(num_epochs=200, end_lr=1e3)
+        lr_finder.run(num_epochs=100, end_lr=1e3)

         plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
@@ -975,10 +988,7 @@ def main():
                 accelerator.backward(loss)

                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(
-                        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-                        args.max_grad_norm
-                    )
+                    on_clip()

                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
diff --git a/training/lr.py b/training/lr.py
index 68e0f72..dfb1743 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -48,7 +48,7 @@ class LRFinder():
         skip_start: int = 10,
         skip_end: int = 5,
         num_epochs: int = 100,
-        num_train_batches: int = 1,
+        num_train_batches: int = math.inf,
         num_val_batches: int = math.inf,
         smooth_f: float = 0.05,
     ):
@@ -156,6 +156,15 @@ class LRFinder():
         # self.model.load_state_dict(self.model_state)
         # self.optimizer.load_state_dict(self.optimizer_state)

+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
         fig, ax_loss = plt.subplots()
         ax_acc = ax_loss.twinx()

@@ -171,15 +180,6 @@ class LRFinder():
         print("LR suggestion: steepest gradient")
         min_grad_idx = None

-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
         try:
             min_grad_idx = np.gradient(np.array(losses)).argmin()
         except ValueError:
--
cgit v1.2.3-70-g09d2
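
Illustrative sketch (not part of the patch): the new `--tag_shuffle` option is a plain
argparse boolean flag that is forwarded to VlpnDataModule as `shuffle=args.tag_shuffle`.
The snippet below mirrors the train_ti.py definition and shows the expected standard
argparse `action="store_true"` semantics; it is a minimal standalone example, not code
from the repository.

import argparse

# Same definition as in train_ti.py: an opt-in flag, off unless passed.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--tag_shuffle",
    action="store_true",
    help="Shuffle tags.",
)

args = parser.parse_args([])                   # flag absent  -> tag_shuffle == False
assert args.tag_shuffle is False
args = parser.parse_args(["--tag_shuffle"])    # flag present -> tag_shuffle == True
assert args.tag_shuffle is True

Note that train_dreambooth.py additionally sets default=True, so tag shuffling stays
enabled there regardless of the flag, while for train_ti.py it defaults to off.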