diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py index 54bbe78..1fdfdc8 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -66,7 +66,12 @@ class TrainingStrategy(): | |||
66 | prepare: TrainingStrategyPrepareCallable | 66 | prepare: TrainingStrategyPrepareCallable |
67 | 67 | ||
68 | 68 | ||
69 | def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): | 69 | def get_models( |
70 | pretrained_model_name_or_path: str, | ||
71 | emb_r: int = 8, | ||
72 | emb_lora_alpha: int = 8, | ||
73 | emb_lora_dropout: float = 0.0 | ||
74 | ): | ||
70 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 75 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
71 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 76 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
72 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 77 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
@@ -75,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): | |||
75 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 80 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
76 | pretrained_model_name_or_path, subfolder='scheduler') | 81 | pretrained_model_name_or_path, subfolder='scheduler') |
77 | 82 | ||
78 | embeddings = patch_managed_embeddings(text_encoder, emb_dropout) | 83 | embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout) |
79 | 84 | ||
80 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | 85 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings |
81 | 86 | ||
@@ -653,6 +658,8 @@ def train_loop( | |||
653 | on_checkpoint(global_step, "end") | 658 | on_checkpoint(global_step, "end") |
654 | raise KeyboardInterrupt | 659 | raise KeyboardInterrupt |
655 | 660 | ||
661 | return avg_loss, avg_acc, avg_loss_val, avg_acc_val | ||
662 | |||
656 | 663 | ||
657 | def train( | 664 | def train( |
658 | accelerator: Accelerator, | 665 | accelerator: Accelerator, |