diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py index 7d49782..e14aeea 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -72,7 +72,7 @@ def make_grid(images, rows, cols): | |||
72 | return grid | 72 | return grid |
73 | 73 | ||
74 | 74 | ||
75 | def get_models(pretrained_model_name_or_path: str): | 75 | def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): |
76 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 76 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
77 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 77 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
78 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 78 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
@@ -81,7 +81,7 @@ def get_models(pretrained_model_name_or_path: str): | |||
81 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 81 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
82 | pretrained_model_name_or_path, subfolder='scheduler') | 82 | pretrained_model_name_or_path, subfolder='scheduler') |
83 | 83 | ||
84 | embeddings = patch_managed_embeddings(text_encoder) | 84 | embeddings = patch_managed_embeddings(text_encoder, emb_dropout) |
85 | 85 | ||
86 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | 86 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings |
87 | 87 | ||