summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py7
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 1fdfdc8..2da0f69 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -68,9 +68,8 @@ class TrainingStrategy():
 
 def get_models(
     pretrained_model_name_or_path: str,
-    emb_r: int = 8,
-    emb_lora_alpha: int = 8,
-    emb_lora_dropout: float = 0.0
+    emb_alpha: int = 8,
+    emb_dropout: float = 0.0
 ):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
@@ -80,7 +79,7 @@ def get_models(
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout)
+    embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout)
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 