From 99b4dba56e3e1e434820d1221d561e90f1a6d30a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 15 Apr 2023 13:11:11 +0200
Subject: TI via LoRA

---
 training/functional.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/training/functional.py b/training/functional.py
index 54bbe78..1fdfdc8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -66,7 +66,12 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable
 
 
-def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
+def get_models(
+    pretrained_model_name_or_path: str,
+    emb_r: int = 8,
+    emb_lora_alpha: int = 8,
+    emb_lora_dropout: float = 0.0
+):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -75,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder, emb_dropout)
+    embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout)
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
@@ -653,6 +658,8 @@ def train_loop(
             on_checkpoint(global_step, "end")
         raise KeyboardInterrupt
 
+    return avg_loss, avg_acc, avg_loss_val, avg_acc_val
+
 
 def train(
     accelerator: Accelerator,
-- 
cgit v1.2.3-54-g00ecf
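
For context, a minimal usage sketch of the changed get_models() signature. This is not part of the patch: the checkpoint path is a placeholder, and the comments assume the usual LoRA conventions (rank, alpha scaling, branch dropout), which this change appears to follow.

    # Sketch only: placeholder model path; keyword names mirror the new signature.
    from training.functional import get_models

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        'path/to/stable-diffusion-model',
        emb_r=8,               # LoRA rank of the token-embedding adapter
        emb_lora_alpha=8,      # LoRA scaling factor, conventionally applied as alpha / r
        emb_lora_dropout=0.0,  # dropout on the LoRA branch, replacing the old emb_dropout
    )

The second hunk means train_loop() now returns its running metrics instead of ending without a value, so a caller can unpack them, e.g. avg_loss, avg_acc, avg_loss_val, avg_acc_val = train_loop(...).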