diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/training/functional.py b/training/functional.py index 1fdfdc8..2da0f69 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -68,9 +68,8 @@ class TrainingStrategy(): | |||
| 68 | 68 | ||
| 69 | def get_models( | 69 | def get_models( |
| 70 | pretrained_model_name_or_path: str, | 70 | pretrained_model_name_or_path: str, |
| 71 | emb_r: int = 8, | 71 | emb_alpha: int = 8, |
| 72 | emb_lora_alpha: int = 8, | 72 | emb_dropout: float = 0.0 |
| 73 | emb_lora_dropout: float = 0.0 | ||
| 74 | ): | 73 | ): |
| 75 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 74 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
| 76 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 75 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
| @@ -80,7 +79,7 @@ def get_models( | |||
| 80 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 79 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
| 81 | pretrained_model_name_or_path, subfolder='scheduler') | 80 | pretrained_model_name_or_path, subfolder='scheduler') |
| 82 | 81 | ||
| 83 | embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout) | 82 | embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout) |
| 84 | 83 | ||
| 85 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | 84 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings |
| 86 | 85 | ||
