From 3ee13893f9a4973ac75f45fe9318c35760dd4b1f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 7 Jan 2023 13:57:46 +0100 Subject: Added progressive aspect ratio bucketing --- train_ti.py | 94 ++++++++++++++++++++++++++++++------------------------------- 1 file changed, 46 insertions(+), 48 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 0ffc9e6..89c6672 100644 --- a/train_ti.py +++ b/train_ti.py @@ -21,7 +21,7 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from data.csv import CSVDataModule, CSVDataItem +from data.csv import VlpnDataModule, VlpnDataItem from training.common import run_model from training.optimization import get_one_cycle_schedule from training.lr import LRFinder @@ -145,11 +145,6 @@ def parse_args(): " resolution" ), ) - parser.add_argument( - "--center_crop", - action="store_true", - help="Whether to center crop images before resizing to resolution" - ) parser.add_argument( "--tag_dropout", type=float, @@ -668,7 +663,7 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - def keyword_filter(item: CSVDataItem): + def keyword_filter(item: VlpnDataItem): cond1 = any( keyword in part for keyword in args.placeholder_token @@ -708,7 +703,7 @@ def main(): } return batch - datamodule = CSVDataModule( + datamodule = VlpnDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, prompt_processor=prompt_processor, @@ -717,7 +712,6 @@ def main(): size=args.resolution, repeats=args.repeats, dropout=args.tag_dropout, - center_crop=args.center_crop, template_key=args.train_data_template, valid_set_size=args.valid_set_size, num_workers=args.dataloader_num_workers, @@ -725,8 +719,6 @@ def main(): filter=keyword_filter, collate_fn=collate_fn ) - - datamodule.prepare_data() datamodule.setup() if args.num_class_images != 0: @@ -769,12 +761,14 @@ def main(): if torch.cuda.is_available(): torch.cuda.empty_cache() - train_dataloader = datamodule.train_dataloader() - val_dataloader = datamodule.val_dataloader() + train_dataloaders = datamodule.train_dataloaders + default_train_dataloader = train_dataloaders[0] + val_dataloader = datamodule.val_dataloader # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) + num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -811,9 +805,10 @@ def main(): num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, ) - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler + text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, val_dataloader, lr_scheduler ) + train_dataloaders = accelerator.prepare(*train_dataloaders) # Move vae and unet to device vae.to(accelerator.device, dtype=weight_dtype) @@ -831,7 +826,8 @@ def main(): unet.eval() # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) + num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch @@ -889,7 +885,7 @@ def main(): accelerator, text_encoder, optimizer, - train_dataloader, + default_train_dataloader, val_dataloader, loop, on_train=on_train, @@ -968,46 +964,48 @@ def main(): text_encoder.train() with on_train(): - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(text_encoder): - loss, acc, bsz = loop(step, batch) + for train_dataloader in train_dataloaders: + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(text_encoder): + loss, acc, bsz = loop(step, batch) - accelerator.backward(loss) + accelerator.backward(loss) - optimizer.step() - if not accelerator.optimizer_step_was_skipped: - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + optimizer.step() + if not accelerator.optimizer_step_was_skipped: + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) - avg_loss.update(loss.detach_(), bsz) - avg_acc.update(acc.detach_(), bsz) + avg_loss.update(loss.detach_(), bsz) + avg_acc.update(acc.detach_(), bsz) - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - if args.use_ema: - ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_embeddings.step( + text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - local_progress_bar.update(1) - global_progress_bar.update(1) + local_progress_bar.update(1) + global_progress_bar.update(1) - global_step += 1 + global_step += 1 - logs = { - "train/loss": avg_loss.avg.item(), - "train/acc": avg_acc.avg.item(), - "train/cur_loss": loss.item(), - "train/cur_acc": acc.item(), - "lr": lr_scheduler.get_last_lr()[0], - } - if args.use_ema: - logs["ema_decay"] = ema_embeddings.decay + logs = { + "train/loss": avg_loss.avg.item(), + "train/acc": avg_acc.avg.item(), + "train/cur_loss": loss.item(), + "train/cur_acc": acc.item(), + "lr": lr_scheduler.get_last_lr()[0], + } + if args.use_ema: + logs["ema_decay"] = ema_embeddings.decay - accelerator.log(logs, step=global_step) + accelerator.log(logs, step=global_step) - local_progress_bar.set_postfix(**logs) + local_progress_bar.set_postfix(**logs) - if global_step >= args.max_train_steps: - break + if global_step >= args.max_train_steps: + break accelerator.wait_for_everyone() -- cgit v1.2.3-54-g00ecf
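Note on the pattern introduced above, with a minimal sketch (not the repository's actual code): the patch replaces the single train_dataloader with one dataloader per aspect-ratio bucket, counts update steps per epoch as the sum over all bucket dataloaders, and walks the buckets sequentially inside each epoch. The bucket datasets, batch size, and tensor shapes below are made-up stand-ins for what VlpnDataModule would produce; only the step math and the loop structure mirror the diff.

    import math
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    gradient_accumulation_steps = 2

    # Hypothetical stand-ins for per-bucket datasets: each bucket holds images of
    # one aspect ratio, so every batch drawn from a bucket has a uniform shape.
    bucket_datasets = [
        TensorDataset(torch.randn(12, 3, 512, 512)),   # square bucket
        TensorDataset(torch.randn(8, 3, 448, 576)),    # portrait bucket
        TensorDataset(torch.randn(6, 3, 576, 448)),    # landscape bucket
    ]
    train_dataloaders = [DataLoader(ds, batch_size=2) for ds in bucket_datasets]

    # Mirrors the patched step math: update steps are summed across all buckets.
    num_update_steps_per_dataloader = sum(len(dl) for dl in train_dataloaders)
    num_update_steps_per_epoch = math.ceil(
        num_update_steps_per_dataloader / gradient_accumulation_steps
    )
    print("update steps per epoch:", num_update_steps_per_epoch)

    # Mirrors the patched inner loop: iterate the buckets one after another
    # within an epoch, accumulating a global step count across all of them.
    global_step = 0
    for train_dataloader in train_dataloaders:
        for step, (batch,) in enumerate(train_dataloader):
            # ... forward pass, loss, accelerator.backward(loss), optimizer.step() ...
            global_step += 1

Keeping each bucket in its own dataloader is what lets batches stay rectangular without center cropping, which is why the --center_crop flag is removed; the first bucket's dataloader (default_train_dataloader in the diff) is still handed to the LR finder as a representative sample.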