From 83725794618164210a12843381724252fdd82cc2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 28 Dec 2022 18:08:36 +0100 Subject: Integrated updates from diffusers --- train_lora.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'train_lora.py') diff --git a/train_lora.py b/train_lora.py index ffca304..9a42cae 100644 --- a/train_lora.py +++ b/train_lora.py @@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule from training.lora import LoraAttnProcessor from training.optimization import get_one_cycle_schedule -from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args +from training.util import AverageMeter, CheckpointerBase, save_args from models.clip.prompt import PromptProcessor logger = get_logger(__name__) @@ -513,11 +513,9 @@ def main(): print(f"Training added text embeddings") - freeze_params(itertools.chain( - text_encoder.text_model.encoder.parameters(), - text_encoder.text_model.final_layer_norm.parameters(), - text_encoder.text_model.embeddings.position_embedding.parameters(), - )) + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) index_fixed_tokens = torch.arange(len(tokenizer)) index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] -- cgit v1.2.3-54-g00ecf