From 5571c4ebcb39813e2bd8585de30c64bb02f9d7fa Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 8 Jan 2023 09:43:22 +0100
Subject: Improved aspect ratio bucketing

---
 train_dreambooth.py | 100 +++++++++++++++++++++++-----------------------------
 1 file changed, 45 insertions(+), 55 deletions(-)

(limited to 'train_dreambooth.py')

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 589af59..42a7d0f 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -133,12 +133,6 @@ def parse_args():
         default="cls",
         help="The directory where class images will be saved.",
     )
-    parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -738,7 +732,6 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -751,7 +744,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()

-    train_dataloaders = datamodule.train_dataloaders
+    train_dataloader = datamodule.train_dataloader
     val_dataloader = datamodule.val_dataloader

     if args.num_class_images != 0:
@@ -770,8 +763,7 @@ def main():

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -820,8 +812,7 @@ def main():
         ema_unet.to(accelerator.device)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

@@ -877,7 +868,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloaders[0],
+        train_dataloader,
         val_dataloader,
         loop,
         on_train=tokenizer.train,
@@ -960,54 +951,53 @@ def main():
             text_encoder.requires_grad_(False)

         with on_train():
-            for train_dataloader in train_dataloaders:
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(unet):
-                        loss, acc, bsz = loop(step, batch)
-
-                        accelerator.backward(loss)
-
-                        if accelerator.sync_gradients:
-                            params_to_clip = (
-                                itertools.chain(unet.parameters(), text_encoder.parameters())
-                                if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                                else unet.parameters()
-                            )
-                            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                        if args.use_ema:
-                            ema_unet.step(unet.parameters())
-                        optimizer.zero_grad(set_to_none=True)
-
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(unet):
+                    loss, acc, bsz = loop(step, batch)

-                        global_step += 1
+                    accelerator.backward(loss)

-                        logs = {
-                            "train/loss": avg_loss.avg.item(),
-                            "train/acc": avg_acc.avg.item(),
-                            "train/cur_loss": loss.item(),
-                            "train/cur_acc": acc.item(),
-                            "lr": lr_scheduler.get_last_lr()[0]
-                        }
+                    if accelerator.sync_gradients:
+                        params_to_clip = (
+                            itertools.chain(unet.parameters(), text_encoder.parameters())
+                            if args.train_text_encoder and epoch < args.train_text_encoder_epochs
+                            else unet.parameters()
+                        )
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
                     if args.use_ema:
-                            logs["ema_decay"] = 1 - ema_unet.decay
+                        ema_unet.step(unet.parameters())
+                    optimizer.zero_grad(set_to_none=True)

-                        accelerator.log(logs, step=global_step)
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)

-                        local_progress_bar.set_postfix(**logs)
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
+
+                    global_step += 1
+
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0]
+                    }
+                    if args.use_ema:
+                        logs["ema_decay"] = 1 - ema_unet.decay
+
+                    accelerator.log(logs, step=global_step)
+
+                    local_progress_bar.set_postfix(**logs)

-                    if global_step >= args.max_train_steps:
-                        break
+                if global_step >= args.max_train_steps:
+                    break

         accelerator.wait_for_everyone()
--
cgit v1.2.3-54-g00ecf
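
Below is a minimal, self-contained sketch of the idea behind this change: aspect-ratio bucketing served through a single dataloader. It is not Volpeon's implementation -- BucketBatchSampler, ToyImageDataset, and the bucket values are all hypothetical -- but it shows why the loop above can drop the per-bucket train_dataloaders list: a batch sampler assigns each sample to the bucket with the nearest aspect ratio and yields same-bucket batches, so one train_dataloader covers every bucket.

# Minimal sketch of aspect-ratio bucketing behind one DataLoader.
# Not the repository's implementation; all names here are illustrative.
import random
from collections import defaultdict

import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class BucketBatchSampler(Sampler):
    """Assigns each sample to its nearest aspect-ratio bucket and yields same-bucket batches."""

    def __init__(self, aspect_ratios, buckets, batch_size, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.bucket_to_indices = defaultdict(list)
        for idx, ar in enumerate(aspect_ratios):
            bucket = min(buckets, key=lambda b: abs(b - ar))
            self.bucket_to_indices[bucket].append(idx)

    def __iter__(self):
        batches = []
        for indices in self.bucket_to_indices.values():
            if self.shuffle:
                random.shuffle(indices)
            # Chunk each bucket separately; every batch is homogeneous in aspect ratio,
            # so its images can be resized to one resolution and stacked.
            for i in range(0, len(indices), self.batch_size):
                batches.append(indices[i:i + self.batch_size])
        if self.shuffle:
            random.shuffle(batches)
        yield from batches

    def __len__(self):
        return sum(
            (len(v) + self.batch_size - 1) // self.batch_size
            for v in self.bucket_to_indices.values()
        )


class ToyImageDataset(Dataset):
    """Stand-in for the real image dataset; returns dummy tensors at each sample's size."""

    def __init__(self, sizes):
        self.sizes = sizes  # list of (width, height) pairs

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx):
        w, h = self.sizes[idx]
        return {"pixel_values": torch.zeros(3, h, w)}


if __name__ == "__main__":
    sizes = [(512, 512), (512, 768), (768, 512), (512, 512), (768, 512)]
    aspect_ratios = [w / h for w, h in sizes]
    dataset = ToyImageDataset(sizes)
    sampler = BucketBatchSampler(aspect_ratios, buckets=[0.67, 1.0, 1.5], batch_size=2)

    # One dataloader covers every bucket, mirroring the single `train_dataloader`
    # the patched training loop iterates over.
    train_dataloader = DataLoader(dataset, batch_sampler=sampler)
    for batch in train_dataloader:
        print(batch["pixel_values"].shape)

Because the sampler's __len__ counts batches across all buckets, len(train_dataloader) gives the per-epoch step count directly, which is what the patch now feeds into math.ceil(len(train_dataloader) / args.gradient_accumulation_steps).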