From 83725794618164210a12843381724252fdd82cc2 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 28 Dec 2022 18:08:36 +0100
Subject: Integrated updates from diffusers

---
 train_dreambooth.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

(limited to 'train_dreambooth.py')

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 325fe90..202d52c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
 from training.ti import patch_trainable_embeddings
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -580,12 +580,10 @@ def main():
 
     patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -905,9 +903,7 @@ def main():
         if epoch < args.train_text_encoder_epochs:
             text_encoder.train()
         elif epoch == args.train_text_encoder_epochs:
-            freeze_params(text_encoder.parameters())
-
-            sample_checkpoint = False
+            text_encoder.requires_grad_(False)
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-- 
cgit v1.2.3-54-g00ecf