From 8f2b8e8d309470babd9b853fde8f0a081366deae Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 11 Dec 2022 12:59:13 +0100
Subject: Training improvements such as tag dropout

---
 data/csv.py   | 13 +++++++--
 dreambooth.py | 94 ++++++++++++++++++++++++++++++++++++++-----------------------
 2 files changed, 70 insertions(+), 37 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 23b5299..9125212 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -16,14 +16,17 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
-def shuffle_prompt(prompt: str):
+def shuffle_prompt(prompt: str, dropout: float = 0):
     def handle_block(block: str):
         words = block.split(", ")
+        words = [w for w in words if w != ""]
+        if dropout != 0:
+            words = [w for w in words if np.random.random() > dropout]
         np.random.shuffle(words)
         return ", ".join(words)
 
     prompt = prompt.split(". ")
-    prompt = [handle_block(b) for b in prompt]
+    prompt = [handle_block(b) for b in prompt if b != ""]
     np.random.shuffle(prompt)
     prompt = ". ".join(prompt)
     return prompt
@@ -48,6 +51,7 @@ class CSVDataModule(pl.LightningDataModule):
         num_class_images: int = 100,
         size: int = 512,
         repeats: int = 1,
+        dropout: float = 0,
         interpolation: str = "bicubic",
         center_crop: bool = False,
         valid_set_size: Optional[int] = None,
@@ -72,6 +76,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.class_identifier = class_identifier
         self.size = size
         self.repeats = repeats
+        self.dropout = dropout
         self.center_crop = center_crop
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
@@ -123,7 +128,7 @@ class CSVDataModule(pl.LightningDataModule):
                                    instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
                                    num_class_images=self.num_class_images,
                                    size=self.size, interpolation=self.interpolation,
-                                   center_crop=self.center_crop, repeats=self.repeats)
+                                   center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout)
         val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
                                  instance_identifier=self.instance_identifier,
                                  size=self.size, interpolation=self.interpolation,
@@ -153,6 +158,7 @@ class CSVDataset(Dataset):
             num_class_images: int = 0,
             size: int = 512,
             repeats: int = 1,
+            dropout: float = 0,
             interpolation: str = "bicubic",
             center_crop: bool = False,
     ):
@@ -163,6 +169,7 @@ class CSVDataset(Dataset):
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.num_class_images = num_class_images
+        self.dropout = dropout
         self.image_cache = {}
 
         self.num_instance_images = len(self.data)
diff --git a/dreambooth.py b/dreambooth.py
index 1ef5156..1d6735f 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -88,6 +88,18 @@ def parse_args():
         default=True,
         help="Whether to train the whole text encoder."
     )
+    parser.add_argument(
+        "--train_text_encoder_epochs",
+        type=int,
+        default=999999,
+        help="Number of epochs the text encoder will be trained."
+    )
+    parser.add_argument(
+        "--tag_dropout",
+        type=float,
+        default=0.1,
+        help="Tag dropout probability.",
+    )
     parser.add_argument(
         "--num_class_images",
         type=int,
@@ -185,9 +197,9 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--lr_warmup_steps",
+        "--lr_warmup_epochs",
         type=int,
-        default=500,
-        help="Number of steps for the warmup in the lr scheduler."
+        default=20,
+        help="Number of epochs for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -377,6 +389,20 @@ def make_grid(images, rows, cols):
     return grid
 
 
+class AverageMeter:
+    def __init__(self, name=None):
+        self.name = name
+        self.reset()
+
+    def reset(self):
+        self.sum = self.count = self.avg = 0
+
+    def update(self, val, n=1):
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
 class Checkpointer:
     def __init__(
         self,
@@ -744,6 +770,7 @@ def main():
             num_class_images=args.num_class_images,
             size=args.resolution,
             repeats=args.repeats,
+            dropout=args.tag_dropout,
             center_crop=args.center_crop,
             valid_set_size=args.valid_set_size,
             num_workers=args.dataloader_num_workers,
@@ -802,6 +829,8 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
+    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
+
     if args.lr_scheduler == "one_cycle":
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
@@ -810,16 +839,16 @@ def main():
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
             num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
+                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
         )
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
         )
 
@@ -863,11 +892,11 @@ def main():
     global_step = 0
 
-    total_loss = 0.0
-    total_acc = 0.0
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
 
-    total_loss_val = 0.0
-    total_acc_val = 0.0
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
 
     max_acc_val = 0.0
 
 
@@ -913,7 +942,11 @@ def main():
         local_progress_bar.reset()
 
         unet.train()
-        text_encoder.train()
+
+        if epoch < args.train_text_encoder_epochs:
+            text_encoder.train()
+        elif epoch == args.train_text_encoder_epochs:
+            freeze_params(text_encoder.parameters())
 
         sample_checkpoint = False
 
@@ -980,7 +1013,7 @@ def main():
                 if accelerator.sync_gradients:
                     params_to_clip = (
                         itertools.chain(unet.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder
+                        if args.train_text_encoder and epoch < args.train_text_encoder_epochs
                         else unet.parameters()
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -992,11 +1025,10 @@ def main():
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-            acc = (model_pred == latents).float()
-            acc = acc.mean()
+            acc = (model_pred == latents).float().mean()
 
-            total_loss += loss.item()
-            total_acc += acc.item()
+            avg_loss.update(loss.detach_(), bsz)
+            avg_acc.update(acc.detach_(), bsz)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -1013,8 +1045,8 @@ def main():
                 sample_checkpoint = True
 
             logs = {
-                "train/loss": total_loss / global_step if global_step != 0 else 0,
-                "train/acc": total_acc / global_step if global_step != 0 else 0,
+                "train/loss": avg_loss.avg.item(),
+                "train/acc": avg_acc.avg.item(),
                 "train/cur_loss": loss.item(),
                 "train/cur_acc": acc.item(),
                 "lr/unet": lr_scheduler.get_last_lr()[0],
@@ -1064,41 +1096,35 @@ def main():
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-                acc = (model_pred == latents).float()
-                acc = acc.mean()
+                acc = (model_pred == latents).float().mean()
 
-                total_loss_val += loss.item()
-                total_acc_val += acc.item()
+                avg_loss_val.update(loss.detach_(), bsz)
+                avg_acc_val.update(acc.detach_(), bsz)
 
                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
                 logs = {
-                    "val/loss": total_loss_val / global_step,
-                    "val/acc": total_acc_val / global_step,
+                    "val/loss": avg_loss_val.avg.item(),
+                    "val/acc": avg_acc_val.avg.item(),
                     "val/cur_loss": loss.item(),
                     "val/cur_acc": acc.item(),
                 }
 
                 local_progress_bar.set_postfix(**logs)
 
-        val_step = (epoch + 1) * len(val_dataloader)
-        avg_acc_val = total_acc_val / val_step
-        avg_loss_val = total_loss_val / val_step
-
-        accelerator.log({
-            "val/loss": avg_loss_val,
-            "val/acc": avg_acc_val,
+        accelerator.log({
+            "val/loss": avg_loss_val.avg.item(),
+            "val/acc": avg_acc_val.avg.item(),
         }, step=global_step)
 
         local_progress_bar.clear()
        global_progress_bar.clear()
 
-        if avg_acc_val > max_acc_val:
+        if avg_acc_val.avg.item() > max_acc_val:
             accelerator.print(
-                f"Global step {global_step}: Validation loss reached new maximum: {max_acc_val:.2e} -> {avg_acc_val:.2e}")
-            checkpointer.save_embedding(global_step, "milestone")
-            max_acc_val = avg_acc_val
+                f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+            max_acc_val = avg_acc_val.avg.item()
 
         if sample_checkpoint and accelerator.is_main_process:
             checkpointer.save_samples(global_step, args.sample_steps)
-- 
cgit v1.2.3-54-g00ecf
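
Note on the tag dropout: shuffle_prompt() in data/csv.py treats a caption as ". "-separated blocks of ", "-separated tags; with dropout > 0 each tag survives with probability 1 - dropout, then tags and blocks are shuffled. A minimal, self-contained sketch of the same logic (the example caption and rate are invented for illustration):

import numpy as np

def shuffle_prompt(prompt: str, dropout: float = 0):
    # Mirrors the patched data/csv.py helper.
    def handle_block(block: str):
        words = block.split(", ")
        words = [w for w in words if w != ""]
        if dropout != 0:
            # Keep each tag with probability (1 - dropout).
            words = [w for w in words if np.random.random() > dropout]
        np.random.shuffle(words)
        return ", ".join(words)

    blocks = [handle_block(b) for b in prompt.split(". ") if b != ""]
    np.random.shuffle(blocks)
    return ". ".join(blocks)

print(shuffle_prompt("a photo of sks dog, sitting, grass, outdoors", dropout=0.1))
# e.g. "grass, a photo of sks dog, outdoors" -- "sitting" happened to be dropped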
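
Note on the warmup change: with --lr_warmup_steps replaced by --lr_warmup_epochs, the warmup is specified in epochs and converted to scheduler steps in main(); the extra gradient_accumulation_steps factor matches how the schedulers are stepped per batch rather than per optimizer update. A worked example (all numbers below are hypothetical):

lr_warmup_epochs = 20             # the new CLI default
num_update_steps_per_epoch = 125  # hypothetical: ceil(batches / grad accumulation)
gradient_accumulation_steps = 4   # hypothetical

warmup_steps = lr_warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
print(warmup_steps)  # 10000 scheduler ticks before warmup ends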
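
Note on AverageMeter: it replaces the total_loss/total_acc running sums, which were normalized by global_step even in the validation logs. Each update is weighted by its batch size, so the running average stays correct when batch sizes differ (in the patch, val is a tensor and bsz the batch size; the plain-float values below are invented):

class AverageMeter:
    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        # Weight each value by the number of samples it represents.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

meter = AverageMeter("loss")
meter.update(0.5, n=4)  # batch of 4 with mean loss 0.5
meter.update(0.3, n=2)  # batch of 2 with mean loss 0.3
print(meter.avg)        # (0.5*4 + 0.3*2) / 6 = 0.4333...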
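
Note on partial text-encoder training: the epoch gate trains the encoder for the first --train_text_encoder_epochs epochs, freezes it exactly once when that epoch is reached, and excludes it from gradient clipping afterwards. A reduced, runnable sketch of the control flow (the stand-in modules, epoch counts, and this freeze_params body are assumptions; the repo's own helper is not shown in the patch):

import torch.nn as nn

def freeze_params(params):
    # Assumed equivalent of the repo's helper: stop tracking gradients.
    for p in params:
        p.requires_grad = False

unet = nn.Linear(4, 4)          # stand-in for the real UNet
text_encoder = nn.Linear(4, 4)  # stand-in for the real text encoder
train_text_encoder_epochs = 2   # hypothetical --train_text_encoder_epochs
num_epochs = 5

for epoch in range(num_epochs):
    unet.train()
    if epoch < train_text_encoder_epochs:
        text_encoder.train()    # epochs 0..N-1: encoder still trains
    elif epoch == train_text_encoder_epochs:
        # Freeze exactly once at epoch N; later epochs skip both branches.
        freeze_params(text_encoder.parameters())

print(any(p.requires_grad for p in text_encoder.parameters()))  # False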