From a72b6260c117cabe4fcb2996cce4f870986df99b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 3 Jan 2023 12:40:16 +0100 Subject: Added vector dropout --- models/clip/tokenizer.py | 27 +++++++++++++++++++++++++-- train_dreambooth.py | 24 +++++++++++++++++++----- train_ti.py | 24 +++++++++++++++++++----- training/lr.py | 9 ++++++--- 4 files changed, 69 insertions(+), 15 deletions(-) diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index bd0bd21..11a3df0 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py @@ -6,6 +6,12 @@ import numpy as np from transformers import CLIPTokenizer +def dropout(tokens: list[int], dropout: float): + if dropout != 0: + tokens = [token for token in tokens if np.random.random() > dropout] + return tokens + + def shuffle_all(tokens: list[int]): if len(tokens) >= 2: tokens = copy.copy(tokens) @@ -59,7 +65,18 @@ class MultiCLIPTokenizer(CLIPTokenizer): super().__init__(*args, **kwargs) self.token_map: dict[int, list[int]] = {} - self.vector_shuffle = shuffle_none + self.is_training = False + self.vector_shuffle = shuffle_auto + self.dropout = 0 + + def train(self): + self.is_training = True + + def eval(self): + self.is_training = False + + def set_dropout(self, dropout: float): + self.dropout = dropout def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): if algorithm == "leading": @@ -105,7 +122,13 @@ class MultiCLIPTokenizer(CLIPTokenizer): return MultiCLIPTokenizerItem(new_tokens, ids) def expand_id(self, id: int): - return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id] + if id in self.token_map: + ids = self.token_map[id] + if self.is_training: + ids = dropout(self.vector_shuffle(ids), self.dropout) + return ids + else: + return [id] def expand_ids(self, ids: list[int]): return [ diff --git a/train_dreambooth.py b/train_dreambooth.py index 218018b..f26b7f5 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ 
-107,6 +107,12 @@ def parse_args(): default=0.1, help="Tag dropout probability.", ) + parser.add_argument( + "--vector_dropout", + type=float, + default=0.1, + help="Vector dropout probability.", + ) parser.add_argument( "--vector_shuffle", type=str, @@ -556,6 +562,8 @@ def main(): tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) elif args.pretrained_model_name_or_path: tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') + tokenizer.set_use_vector_shuffle(args.vector_shuffle) + tokenizer.set_dropout(args.vector_dropout) # Load models and create wrapper for stable diffusion text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') @@ -826,6 +834,12 @@ def main(): num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) val_steps = num_val_steps_per_epoch * num_epochs + def on_train(): + tokenizer.train() + + def on_eval(): + tokenizer.eval() + def loop(batch): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample() @@ -898,8 +912,8 @@ def main(): train_dataloader, val_dataloader, loop, - on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), - on_eval=lambda: tokenizer.set_use_vector_shuffle(False) + on_train=tokenizer.train, + on_eval=tokenizer.eval, ) lr_finder.run(end_lr=1e2) @@ -953,7 +967,7 @@ def main(): disable=not accelerator.is_local_main_process, dynamic_ncols=True ) - local_progress_bar.set_description("Epoch X / Y") + local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") global_progress_bar = tqdm( range(args.max_train_steps + val_steps), @@ -976,7 +990,7 @@ def main(): text_encoder.train() elif epoch == args.train_text_encoder_epochs: text_encoder.requires_grad_(False) - tokenizer.set_use_vector_shuffle(args.vector_shuffle) + on_train() for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): @@ -1030,7 +1044,7 @@ def main():
unet.eval() text_encoder.eval() - tokenizer.set_use_vector_shuffle(False) + on_eval() cur_loss_val = AverageMeter() cur_acc_val = AverageMeter() diff --git a/train_ti.py b/train_ti.py index 102c0fa..cacbbc7 100644 --- a/train_ti.py +++ b/train_ti.py @@ -154,6 +154,12 @@ def parse_args(): default=0.1, help="Tag dropout probability.", ) + parser.add_argument( + "--vector_dropout", + type=float, + default=0.1, + help="Vector dropout probability.", + ) parser.add_argument( "--vector_shuffle", type=str, @@ -526,6 +532,8 @@ def main(): tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) elif args.pretrained_model_name_or_path: tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') + tokenizer.set_use_vector_shuffle(args.vector_shuffle) + tokenizer.set_dropout(args.vector_dropout) # Load models and create wrapper for stable diffusion text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') @@ -777,6 +785,12 @@ def main(): num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) val_steps = num_val_steps_per_epoch * num_epochs + def on_train(): + tokenizer.train() + + def on_eval(): + tokenizer.eval() + def loop(batch): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() @@ -850,8 +864,8 @@ def main(): train_dataloader, val_dataloader, loop, - on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), - on_eval=lambda: tokenizer.set_use_vector_shuffle(False) + on_train=on_train, + on_eval=on_eval, ) lr_finder.run(end_lr=1e2) @@ -903,7 +917,7 @@ def main(): disable=not accelerator.is_local_main_process, dynamic_ncols=True ) - local_progress_bar.set_description("Epoch X / Y") + local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") global_progress_bar = tqdm( range(args.max_train_steps + val_steps), @@ -922,7 +936,7 @@ def main(): local_progress_bar.reset()
text_encoder.train() - tokenizer.set_use_vector_shuffle(args.vector_shuffle) + on_train() for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): @@ -963,7 +977,7 @@ def main(): accelerator.wait_for_everyone() text_encoder.eval() - tokenizer.set_use_vector_shuffle(False) + on_eval() cur_loss_val = AverageMeter() cur_acc_val = AverageMeter() diff --git a/training/lr.py b/training/lr.py index acc01a2..37588b6 100644 --- a/training/lr.py +++ b/training/lr.py @@ -58,7 +58,11 @@ class LRFinder(): losses = [] accs = [] - lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs) + lr_scheduler = get_exponential_schedule( + self.optimizer, + end_lr, + num_epochs * min(num_train_batches, len(self.train_dataloader)) + ) steps = min(num_train_batches, len(self.train_dataloader)) steps += min(num_val_batches, len(self.val_dataloader)) @@ -90,6 +94,7 @@ class LRFinder(): self.accelerator.backward(loss) self.optimizer.step() + lr_scheduler.step() self.optimizer.zero_grad(set_to_none=True) if self.accelerator.sync_gradients: @@ -109,8 +114,6 @@ class LRFinder(): progress_bar.update(1) - lr_scheduler.step() - loss = avg_loss.avg.item() acc = avg_acc.avg.item() -- cgit v1.2.3-70-g09d2