diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/training/functional.py b/training/functional.py index 1fdfdc8..2da0f69 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -68,9 +68,8 @@ class TrainingStrategy(): | |||
68 | 68 | ||
69 | def get_models( | 69 | def get_models( |
70 | pretrained_model_name_or_path: str, | 70 | pretrained_model_name_or_path: str, |
71 | emb_r: int = 8, | 71 | emb_alpha: int = 8, |
72 | emb_lora_alpha: int = 8, | 72 | emb_dropout: float = 0.0 |
73 | emb_lora_dropout: float = 0.0 | ||
74 | ): | 73 | ): |
75 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 74 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
76 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 75 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
@@ -80,7 +79,7 @@ def get_models( | |||
80 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 79 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
81 | pretrained_model_name_or_path, subfolder='scheduler') | 80 | pretrained_model_name_or_path, subfolder='scheduler') |
82 | 81 | ||
83 | embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout) | 82 | embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout) |
84 | 83 | ||
85 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | 84 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings |
86 | 85 | ||