summary | refs | log | tree | commit | diff | stats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r-- training/functional.py 11
1 files changed, 9 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py
index 54bbe78..1fdfdc8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -66,7 +66,12 @@ class TrainingStrategy():
66 prepare: TrainingStrategyPrepareCallable 66 prepare: TrainingStrategyPrepareCallable
67 67
68 68
69 def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): 69 def get_models(
70 pretrained_model_name_or_path: str,
71 emb_r: int = 8,
72 emb_lora_alpha: int = 8,
73 emb_lora_dropout: float = 0.0
74):
70 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 75 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
71 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 76 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
72 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 77 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -75,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
75 sample_scheduler = UniPCMultistepScheduler.from_pretrained( 80 sample_scheduler = UniPCMultistepScheduler.from_pretrained(
76 pretrained_model_name_or_path, subfolder='scheduler') 81 pretrained_model_name_or_path, subfolder='scheduler')
77 82
78 embeddings = patch_managed_embeddings(text_encoder, emb_dropout) 83 embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout)
79 84
80 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings 85 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
81 86
@@ -653,6 +658,8 @@ def train_loop(
653 on_checkpoint(global_step, "end") 658 on_checkpoint(global_step, "end")
654 raise KeyboardInterrupt 659 raise KeyboardInterrupt
655 660
661 return avg_loss, avg_acc, avg_loss_val, avg_acc_val
662
656 663
657 def train( 664 def train(
658 accelerator: Accelerator, 665 accelerator: Accelerator,