diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py index 96ecbc1..1d8e2ee 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -73,7 +73,7 @@ def make_grid(images, rows, cols): | |||
73 | return grid | 73 | return grid |
74 | 74 | ||
75 | 75 | ||
76 | def get_models(pretrained_model_name_or_path: str): | 76 | def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0): |
77 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 77 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
78 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 78 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
79 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 79 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
@@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str): | |||
82 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 82 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
83 | pretrained_model_name_or_path, subfolder='scheduler') | 83 | pretrained_model_name_or_path, subfolder='scheduler') |
84 | 84 | ||
85 | embeddings = patch_managed_embeddings(text_encoder) | 85 | embeddings = patch_managed_embeddings(text_encoder, emb_alpha) |
86 | 86 | ||
87 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | 87 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings |
88 | 88 | ||