From 8f2b8e8d309470babd9b853fde8f0a081366deae Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 11 Dec 2022 12:59:13 +0100
Subject: Training improvements such as tag dropout

---
 dreambooth.py | 91 +++++++++++++++++++++++++++++++++++++----------------------
 1 file changed, 58 insertions(+), 33 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index 1ef5156..1d6735f 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -88,6 +88,17 @@ def parse_args():
         default=True,
         help="Whether to train the whole text encoder."
     )
+    parser.add_argument(
+        "--train_text_encoder_epochs",
+        default=999999,
+        help="Number of epochs the text encoder will be trained."
+    )
+    parser.add_argument(
+        "--tag_dropout",
+        type=float,
+        default=0.1,
+        help="Tag dropout probability.",
+    )
     parser.add_argument(
         "--num_class_images",
         type=int,
@@ -185,9 +196,9 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--lr_warmup_steps",
+        "--lr_warmup_epochs",
         type=int,
-        default=500,
+        default=20,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -377,6 +388,20 @@ def make_grid(images, rows, cols):
     return grid
 
 
+class AverageMeter:
+    def __init__(self, name=None):
+        self.name = name
+        self.reset()
+
+    def reset(self):
+        self.sum = self.count = self.avg = 0
+
+    def update(self, val, n=1):
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
 class Checkpointer:
     def __init__(
         self,
@@ -744,6 +769,7 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
+        dropout=args.tag_dropout,
         center_crop=args.center_crop,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
@@ -802,6 +828,8 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
+    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
+
     if args.lr_scheduler == "one_cycle":
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
@@ -810,16 +838,16 @@ def main():
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
             num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
+                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
         )
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
         )
 
@@ -863,11 +891,11 @@ def main():
 
     global_step = 0
 
-    total_loss = 0.0
-    total_acc = 0.0
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
 
-    total_loss_val = 0.0
-    total_acc_val = 0.0
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
 
     max_acc_val = 0.0
 
@@ -913,7 +941,11 @@ def main():
         local_progress_bar.reset()
 
         unet.train()
-        text_encoder.train()
+
+        if epoch < args.train_text_encoder_epochs:
+            text_encoder.train()
+        elif epoch == args.train_text_encoder_epochs:
+            freeze_params(text_encoder.parameters())
 
         sample_checkpoint = False
 
@@ -980,7 +1012,7 @@ def main():
                 if accelerator.sync_gradients:
                     params_to_clip = (
                         itertools.chain(unet.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder
+                        if args.train_text_encoder and epoch < args.train_text_encoder_epochs
                         else unet.parameters()
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -992,11 +1024,10 @@ def main():
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-            acc = (model_pred == latents).float()
-            acc = acc.mean()
+            acc = (model_pred == latents).float().mean()
 
-            total_loss += loss.item()
-            total_acc += acc.item()
+            avg_loss.update(loss.detach_(), bsz)
+            avg_acc.update(acc.detach_(), bsz)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -1013,8 +1044,8 @@ def main():
                     sample_checkpoint = True
 
             logs = {
-                "train/loss": total_loss / global_step if global_step != 0 else 0,
-                "train/acc": total_acc / global_step if global_step != 0 else 0,
+                "train/loss": avg_loss.avg.item(),
+                "train/acc": avg_acc.avg.item(),
                 "train/cur_loss": loss.item(),
                 "train/cur_acc": acc.item(),
                 "lr/unet": lr_scheduler.get_last_lr()[0],
@@ -1064,41 +1095,35 @@ def main():
 
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-                acc = (model_pred == latents).float()
-                acc = acc.mean()
+                acc = (model_pred == latents).float().mean()
 
-                total_loss_val += loss.item()
-                total_acc_val += acc.item()
+                avg_loss_val.update(loss.detach_(), bsz)
+                avg_acc_val.update(acc.detach_(), bsz)
 
                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
                 logs = {
-                    "val/loss": total_loss_val / global_step,
-                    "val/acc": total_acc_val / global_step,
+                    "val/loss": avg_loss_val.avg.item(),
+                    "val/acc": avg_acc_val.avg.item(),
                     "val/cur_loss": loss.item(),
                     "val/cur_acc": acc.item(),
                 }
                 local_progress_bar.set_postfix(**logs)
 
-            val_step = (epoch + 1) * len(val_dataloader)
-            avg_acc_val = total_acc_val / val_step
-            avg_loss_val = total_loss_val / val_step
-
             accelerator.log({
-                "val/loss": avg_loss_val,
-                "val/acc": avg_acc_val,
+                "val/loss": avg_loss_val.avg.item(),
+                "val/acc": avg_acc_val.avg.item(),
             }, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
 
-            if avg_acc_val > max_acc_val:
+            if avg_acc_val.avg.item() > max_acc_val:
                 accelerator.print(
-                    f"Global step {global_step}: Validation loss reached new maximum: {max_acc_val:.2e} -> {avg_acc_val:.2e}")
-                checkpointer.save_embedding(global_step, "milestone")
-                max_acc_val = avg_acc_val
+                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                max_acc_val = avg_acc_val.avg.item()
 
         if sample_checkpoint and accelerator.is_main_process:
             checkpointer.save_samples(global_step, args.sample_steps)
-- 
cgit v1.2.3-54-g00ecf