From 99b4dba56e3e1e434820d1221d561e90f1a6d30a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 15 Apr 2023 13:11:11 +0200
Subject: TI via LoRA

---
 training/strategy/lora.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1517ee8..48236fb 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -93,7 +93,7 @@ def lora_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
+                for p in text_encoder.text_model.embeddings.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
@@ -180,7 +180,7 @@ def lora_prepare(
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
 
-    text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
+    # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
 
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
-- 
cgit v1.2.3-54-g00ecf
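
For reference, the first hunk widens the embedding-decay hook from the removed
token_override_embedding module to all embedding parameters. Below is a minimal
sketch of that pattern, assuming a standard PyTorch setup; the Embeddings
stand-in and the gather_decay_params name are hypothetical, and only the list
comprehension and torch.stack call mirror the patched hunk:

import torch
import torch.nn as nn
from typing import Optional

class Embeddings(nn.Module):
    # Hypothetical stand-in for text_encoder.text_model.embeddings.
    def __init__(self, vocab_size: int = 8, dim: int = 4):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.position_embedding = nn.Embedding(vocab_size, dim)

def gather_decay_params(embeddings: nn.Module) -> Optional[torch.Tensor]:
    # Collect every embedding parameter that received a gradient this step,
    # so decay touches only tensors the optimizer is about to update.
    params = [p for p in embeddings.parameters() if p.grad is not None]
    return torch.stack(params) if len(params) != 0 else None

emb = Embeddings()
emb.token_embedding(torch.tensor([1, 2])).sum().backward()
w = gather_decay_params(emb)
print(None if w is None else w.shape)  # torch.Size([1, 8, 4])

Because gradients only reach the token embedding in this toy run, only its
weight is collected; note that torch.stack assumes every collected parameter
has the same shape.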