From 83725794618164210a12843381724252fdd82cc2 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 28 Dec 2022 18:08:36 +0100
Subject: Integrated updates from diffusers

---
 train_ti.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 870b2ba..d7696e5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -25,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.ti import patch_trainable_embeddings
-from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -533,12 +533,10 @@ def main():
 
     patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -548,6 +546,9 @@ def main():
             args.train_batch_size * accelerator.num_processes
         )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -715,7 +716,11 @@ def main():
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
-    unet.eval()
+
+    if args.gradient_checkpointing:
+        unet.train()
+    else:
+        unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-- 
cgit v1.2.3-54-g00ecf
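
Note: below is a minimal, illustrative sketch of the two patterns this patch relies on, not part of the patch itself. It assumes a diffusers-style UNet whose checkpointed blocks only activate when the module is in training mode, and variable names (text_encoder, unet, args) matching those used in train_ti.py; the helper function names are invented for illustration.

# Sketch: freeze text-encoder submodules in place, and keep a frozen unet in
# train() mode whenever gradient checkpointing is enabled.

def freeze_text_encoder(text_encoder):
    # requires_grad_(False) stands in for the removed freeze_params(...) helper:
    # it recursively disables gradients for every parameter of the submodule,
    # leaving only the patched trainable embeddings with requires_grad=True.
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)


def set_unet_mode(unet, gradient_checkpointing):
    # diffusers guards its checkpointed blocks with `self.training`, so a unet
    # left in eval() silently skips checkpointing; switching it to train()
    # keeps the memory savings even though its weights stay frozen.
    if gradient_checkpointing:
        unet.train()
    else:
        unet.eval()

Calling something like set_unet_mode(unet, args.gradient_checkpointing) right after vae.eval() mirrors the change in the last hunk above.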