From 3353ffb64c280a938a0f2513d13b716c1fca8c02 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 7 Jan 2023 17:10:06 +0100
Subject: Cleanup

---
 train_ti.py | 60 ++++++++++++++++--------------------------------------------
 1 file changed, 16 insertions(+), 44 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 38c9755..b4b602b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -22,7 +22,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import run_model
+from training.common import run_model, generate_class_images
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
@@ -219,7 +219,6 @@ def parse_args():
     parser.add_argument(
         "--scale_lr",
         action="store_true",
-        default=True,
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
     )
     parser.add_argument(
@@ -734,50 +733,23 @@ def main():
     )
     datamodule.setup()
 
-    if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
-
-        if len(missing_data) != 0:
-            batched_data = [
-                missing_data[i:i+args.sample_batch_size]
-                for i in range(0, len(missing_data), args.sample_batch_size)
-            ]
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=vae,
-                unet=unet,
-                tokenizer=tokenizer,
-                scheduler=checkpoint_scheduler,
-            ).to(accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            with torch.inference_mode():
-                for batch in batched_data:
-                    image_name = [item.class_image_path for item in batch]
-                    prompt = [item.cprompt for item in batch]
-                    nprompt = [item.nprompt for item in batch]
-
-                    images = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=args.sample_image_size,
-                        width=args.sample_image_size,
-                        num_inference_steps=args.sample_steps
-                    ).images
-
-                    for i, image in enumerate(images):
-                        image.save(image_name[i])
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
     train_dataloaders = datamodule.train_dataloaders
-    default_train_dataloader = train_dataloaders[0]
     val_dataloader = datamodule.val_dataloader
 
+    if args.num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            checkpoint_scheduler,
+            datamodule.data_train,
+            args.sample_batch_size,
+            args.sample_image_size,
+            args.sample_steps
+        )
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
@@ -898,7 +870,7 @@ def main():
                 accelerator,
                 text_encoder,
                 optimizer,
-                default_train_dataloader,
+                train_dataloaders[0],
                 val_dataloader,
                 loop,
                 on_train=on_train,
-- 
cgit v1.2.3-54-g00ecf
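
Note: this patch view is limited to train_ti.py, so the definition of the new generate_class_images helper in training/common.py is not shown. A minimal sketch of what it might look like, reconstructed from the inline code removed above — the parameter names are assumptions inferred from the call site, not confirmed by this patch:

# training/common.py (sketch, not part of this patch)
import torch

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion


def generate_class_images(
    accelerator,
    text_encoder,
    vae,
    unet,
    tokenizer,
    scheduler,
    data_train,
    sample_batch_size,
    sample_image_size,
    sample_steps
):
    # Only items whose class image does not exist yet need to be generated.
    missing_data = [item for item in data_train if not item.class_image_path.exists()]

    if len(missing_data) == 0:
        return

    # Split the missing items into batches of sample_batch_size.
    batched_data = [
        missing_data[i:i + sample_batch_size]
        for i in range(0, len(missing_data), sample_batch_size)
    ]

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    with torch.inference_mode():
        for batch in batched_data:
            image_name = [item.class_image_path for item in batch]
            prompt = [item.cprompt for item in batch]
            nprompt = [item.nprompt for item in batch]

            images = pipeline(
                prompt=prompt,
                negative_prompt=nprompt,
                height=sample_image_size,
                width=sample_image_size,
                num_inference_steps=sample_steps
            ).images

            # Save each generated image to its item's class_image_path.
            for i, image in enumerate(images):
                image.save(image_name[i])

    # Free the pipeline and reclaim GPU memory before training starts.
    del pipeline

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

The logic is unchanged from the removed inline block; hoisting it into training.common presumably lets the other training scripts in the repository share the same routine.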