diff options
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py index 54bbe78..1fdfdc8 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -66,7 +66,12 @@ class TrainingStrategy(): | |||
| 66 | prepare: TrainingStrategyPrepareCallable | 66 | prepare: TrainingStrategyPrepareCallable |
| 67 | 67 | ||
| 68 | 68 | ||
| 69 | def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): | 69 | def get_models( |
| 70 | pretrained_model_name_or_path: str, | ||
| 71 | emb_r: int = 8, | ||
| 72 | emb_lora_alpha: int = 8, | ||
| 73 | emb_lora_dropout: float = 0.0 | ||
| 74 | ): | ||
| 70 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 75 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
| 71 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 76 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
| 72 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 77 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
| @@ -75,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): | |||
| 75 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 80 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
| 76 | pretrained_model_name_or_path, subfolder='scheduler') | 81 | pretrained_model_name_or_path, subfolder='scheduler') |
| 77 | 82 | ||
| 78 | embeddings = patch_managed_embeddings(text_encoder, emb_dropout) | 83 | embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout) |
| 79 | 84 | ||
| 80 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | 85 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings |
| 81 | 86 | ||
| @@ -653,6 +658,8 @@ def train_loop( | |||
| 653 | on_checkpoint(global_step, "end") | 658 | on_checkpoint(global_step, "end") |
| 654 | raise KeyboardInterrupt | 659 | raise KeyboardInterrupt |
| 655 | 660 | ||
| 661 | return avg_loss, avg_acc, avg_loss_val, avg_acc_val | ||
| 662 | |||
| 656 | 663 | ||
| 657 | def train( | 664 | def train( |
| 658 | accelerator: Accelerator, | 665 | accelerator: Accelerator, |
