From 99b4dba56e3e1e434820d1221d561e90f1a6d30a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 15 Apr 2023 13:11:11 +0200
Subject: TI via LoRA

---
 training/functional.py    | 11 +++++++++--
 training/strategy/lora.py |  4 ++--
 training/strategy/ti.py   |  9 ++++-----
 3 files changed, 15 insertions(+), 9 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index 54bbe78..1fdfdc8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -66,7 +66,12 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable


-def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
+def get_models(
+    pretrained_model_name_or_path: str,
+    emb_r: int = 8,
+    emb_lora_alpha: int = 8,
+    emb_lora_dropout: float = 0.0
+):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -75,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')

-    embeddings = patch_managed_embeddings(text_encoder, emb_dropout)
+    embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout)

     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings

@@ -653,6 +658,8 @@ def train_loop(
                 on_checkpoint(global_step, "end")
                 raise KeyboardInterrupt

+    return avg_loss, avg_acc, avg_loss_val, avg_acc_val
+

 def train(
     accelerator: Accelerator,
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1517ee8..48236fb 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -93,7 +93,7 @@ def lora_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
+                for p in text_encoder.text_model.embeddings.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
@@ -180,7 +180,7 @@ def lora_prepare(
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)

-    text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
+    # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)

     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler

diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index ca7cc3d..49236c6 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -72,7 +72,7 @@ def textual_inversion_strategy_callbacks(

     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+            text_encoder.text_model.embeddings.token_embedding.parameters(),
             inv_gamma=ema_inv_gamma,
             power=ema_power,
             max_value=ema_max_decay,
@@ -84,7 +84,7 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.token_override_embedding.parameters()
+                text_encoder.text_model.embeddings.token_embedding.parameters()
             )
         else:
             return nullcontext()
@@ -108,7 +108,7 @@ def textual_inversion_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
+                for p in text_encoder.text_model.embeddings.token_embedding.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
@@ -116,7 +116,7 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters())
+            ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())

         if use_emb_decay and w is not None:
             lr = lrs["emb"] if "emb" in lrs else lrs["0"]
@@ -203,7 +203,6 @@ def textual_inversion_prepare(
     text_encoder.text_model.encoder.requires_grad_(False)
     text_encoder.text_model.final_layer_norm.requires_grad_(False)
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)

     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
--
cgit v1.2.3-54-g00ecf
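A minimal usage sketch for the get_models() signature changed by this patch, assuming the repository's training package is importable. The checkpoint path and the dropout value below are illustrative placeholders, not values taken from the commit; rank and alpha of 8 mirror the new defaults.

    from training.functional import get_models

    # Load all pipeline components; per this patch, the embedding-related LoRA
    # settings are forwarded to patch_managed_embeddings() via the new arguments.
    (tokenizer, text_encoder, vae, unet,
     noise_scheduler, sample_scheduler, embeddings) = get_models(
        "path/to/pretrained-model",  # placeholder checkpoint path
        emb_r=8,                     # LoRA rank for the embedding adapter
        emb_lora_alpha=8,            # LoRA scaling factor
        emb_lora_dropout=0.1,        # illustrative dropout; the default is 0.0
    )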