From 89d471652644f449966a0cd944041c98dab7f66c Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 07:25:24 +0100
Subject: Code deduplication

---
 train_ti.py | 86 +++++++++++++++++++------------------------------------------
 1 file changed, 26 insertions(+), 60 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 9ec5cfb..3b7e3b1 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,7 +13,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel
@@ -22,8 +21,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images
-from training.optimization import get_one_cycle_schedule
+from training.common import loss_step, generate_class_images, get_scheduler
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
 from models.clip.embeddings import patch_managed_embeddings
@@ -410,10 +408,16 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--max_grad_norm",
-        default=3.0,
+        "--decay_target",
+        default=0.4,
         type=float,
-        help="Max gradient norm."
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--decay_factor",
+        default=100,
+        type=float,
+        help="Embedding decay factor."
     )
     parser.add_argument(
         "--noise_timesteps",
@@ -709,35 +713,6 @@ def main():
         )
         return cond1 and cond3 and cond4
 
-    def collate_fn(examples):
-        prompt_ids = [example["prompt_ids"] for example in examples]
-        nprompt_ids = [example["nprompt_ids"] for example in examples]
-
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # concat class and instance examples for prior preservation
-        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
-
-        prompts = prompt_processor.unify_input_ids(prompt_ids)
-        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-        inputs = prompt_processor.unify_input_ids(input_ids)
-
-        batch = {
-            "prompt_ids": prompts.input_ids,
-            "nprompt_ids": nprompts.input_ids,
-            "input_ids": inputs.input_ids,
-            "pixel_values": pixel_values,
-            "attention_mask": inputs.attention_mask,
-        }
-
-        return batch
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -757,7 +732,7 @@ def main():
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
-        collate_fn=collate_fn
+        dtype=weight_dtype
     )
     datamodule.setup()
 
@@ -786,35 +761,23 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
-
     if args.find_lr:
         lr_scheduler = None
-    elif args.lr_scheduler == "one_cycle":
-        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            warmup=args.lr_warmup_func,
-            annealing=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            min_lr=lr_min_lr,
-        )
-    elif args.lr_scheduler == "cosine_with_restarts":
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
-        )
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            lr=args.learning_rate,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            warmup_epochs=args.lr_warmup_epochs,
+            max_train_steps=args.max_train_steps,
+            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            gradient_accumulation_steps=args.gradient_accumulation_steps
        )
 
     text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
@@ -868,7 +831,10 @@ def main():
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+        text_encoder.text_model.embeddings.normalize(
+            args.decay_target,
+            min(1.0, args.decay_factor * lr)
+        )
 
     loop = partial(
         loss_step,
--
cgit v1.2.3-54-g00ecf
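
Note on the scheduler change above: the branching that used to live in train_ti.py is now delegated to a single training.common.get_scheduler call. The helper below is only a sketch of what that consolidated function plausibly looks like, reconstructed from the branches this commit deletes and from the keyword arguments passed at the new call site; the actual code in training/common.py may differ in details.

# Sketch only; signature inferred from the call site in this diff.
import math

from diffusers.optimization import (
    get_scheduler as get_scheduler_,
    get_cosine_with_hard_restarts_schedule_with_warmup,
)

from training.optimization import get_one_cycle_schedule


def get_scheduler(
    name,
    optimizer,
    min_lr,
    lr,
    warmup_func,
    annealing_func,
    warmup_exp,
    annealing_exp,
    cycles,
    warmup_epochs,
    max_train_steps,
    num_update_steps_per_epoch,
    gradient_accumulation_steps,
):
    num_training_steps = max_train_steps * gradient_accumulation_steps
    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps

    if name == "one_cycle":
        # Same default as the deleted train_ti.py branch: 4% of the peak LR.
        min_lr = 0.04 if min_lr is None else min_lr / lr
        return get_one_cycle_schedule(
            optimizer=optimizer,
            num_training_steps=num_training_steps,
            warmup=warmup_func,
            annealing=annealing_func,
            warmup_exp=warmup_exp,
            annealing_exp=annealing_exp,
            min_lr=min_lr,
        )

    if name == "cosine_with_restarts":
        return get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=cycles or math.ceil(
                math.sqrt((max_train_steps - warmup_steps) / num_update_steps_per_epoch)
            ),
        )

    # Everything else falls through to the stock diffusers schedulers.
    return get_scheduler_(
        name,
        optimizer=optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps,
    )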
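
The other behavioural change is the on_after_optimize hook: the hard-coded decay strength min(1.0, 100 * lr) is replaced by the new --decay_target and --decay_factor flags. The snippet below only illustrates the idea of decaying trained embedding norms toward a target value; the function name and the exact update rule are assumptions, not the actual code behind embeddings.normalize in models/clip/embeddings.py.

# Illustration only; names and update rule are hypothetical.
import torch
import torch.nn.functional as F


@torch.no_grad()
def decay_embedding_norms(weight, token_ids, target=0.4, lambda_=1.0):
    # Current norms of the trainable embedding rows.
    rows = weight[token_ids]
    norms = rows.norm(dim=-1, keepdim=True)
    # Move each norm a fraction lambda_ of the way toward the target norm,
    # keeping the direction of each embedding vector unchanged.
    new_norms = norms + lambda_ * (target - norms)
    weight[token_ids] = F.normalize(rows, dim=-1) * new_norms


# Usage mirroring the hook in the diff: the strength scales with the current
# learning rate and is capped at 1.0, e.g.
# decay_embedding_norms(embedding.weight, placeholder_ids,
#                       target=0.4, lambda_=min(1.0, 100 * lr))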