| author | Volpeon <git@volpeon.ink> | 2023-01-14 22:42:44 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-14 22:42:44 +0100 |
| commit | f00877a13bce50b02cfc3790f2d18a325e9ff95b | |
| tree | ebbda04024081e9c3c00400fae98124f3db2cc9c /training/functional.py | |
| parent | Update (diff) | |
Update
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 34 |
1 file changed, 24 insertions(+), 10 deletions(-)
```diff
diff --git a/training/functional.py b/training/functional.py
index c100ea2..c5b514a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -25,17 +25,31 @@ def const(result=None):
     return fn
 
 
+def get_models(pretrained_model_name_or_path: str):
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+
+
 def generate_class_images(
-    accelerator,
-    text_encoder,
-    vae,
-    unet,
-    tokenizer,
-    scheduler,
+    accelerator: Accelerator,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    unet: UNet2DConditionModel,
+    tokenizer: MultiCLIPTokenizer,
+    sample_scheduler: DPMSolverMultistepScheduler,
     data_train,
-    sample_batch_size,
-    sample_image_size,
-    sample_steps
+    sample_batch_size: int,
+    sample_image_size: int,
+    sample_steps: int
 ):
     missing_data = [item for item in data_train if not item.class_image_path.exists()]
 
@@ -52,7 +66,7 @@ def generate_class_images(
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=scheduler,
+        scheduler=sample_scheduler,
     ).to(accelerator.device)
     pipeline.set_progress_bar_config(dynamic_ncols=True)
 
```
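For orientation, a minimal usage sketch of the new helper (not part of this commit: the import path assumes `get_models` remains in `training/functional.py`, and the checkpoint path is a placeholder for any diffusers-style model directory containing the subfolders referenced above):

```python
# Hypothetical usage sketch, not code from this commit.
from training.functional import get_models

# get_models() loads every component from one diffusers-style checkpoint
# and patches the text encoder for managed embeddings in a single call.
(tokenizer, text_encoder, vae, unet,
 noise_scheduler, sample_scheduler, embeddings) = get_models(
    'path/to/stable-diffusion-checkpoint'  # placeholder path
)
```

Note that both schedulers are loaded from the same `scheduler` subfolder but serve different roles: `noise_scheduler` (DDPM) drives the training noise schedule, while `sample_scheduler` (DPM-Solver multistep) is used for generation. That split is why the `generate_class_images` parameter is renamed from `scheduler` to `sample_scheduler` in this diff.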
