diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 34 |
1 files changed, 24 insertions, 10 deletions
diff --git a/training/functional.py b/training/functional.py index c100ea2..c5b514a 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -25,17 +25,31 @@ def const(result=None): | |||
25 | return fn | 25 | return fn |
26 | 26 | ||
27 | 27 | ||
28 | def get_models(pretrained_model_name_or_path: str): | ||
29 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | ||
30 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | ||
31 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | ||
32 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') | ||
33 | noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | ||
34 | sample_scheduler = DPMSolverMultistepScheduler.from_pretrained( | ||
35 | pretrained_model_name_or_path, subfolder='scheduler') | ||
36 | |||
37 | embeddings = patch_managed_embeddings(text_encoder) | ||
38 | |||
39 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | ||
40 | |||
41 | |||
28 | def generate_class_images( | 42 | def generate_class_images( |
29 | accelerator, | 43 | accelerator: Accelerator, |
30 | text_encoder, | 44 | text_encoder: CLIPTextModel, |
31 | vae, | 45 | vae: AutoencoderKL, |
32 | unet, | 46 | unet: UNet2DConditionModel, |
33 | tokenizer, | 47 | tokenizer: MultiCLIPTokenizer, |
34 | scheduler, | 48 | sample_scheduler: DPMSolverMultistepScheduler, |
35 | data_train, | 49 | data_train, |
36 | sample_batch_size, | 50 | sample_batch_size: int, |
37 | sample_image_size, | 51 | sample_image_size: int, |
38 | sample_steps | 52 | sample_steps: int |
39 | ): | 53 | ): |
40 | missing_data = [item for item in data_train if not item.class_image_path.exists()] | 54 | missing_data = [item for item in data_train if not item.class_image_path.exists()] |
41 | 55 | ||
@@ -52,7 +66,7 @@ def generate_class_images( | |||
52 | vae=vae, | 66 | vae=vae, |
53 | unet=unet, | 67 | unet=unet, |
54 | tokenizer=tokenizer, | 68 | tokenizer=tokenizer, |
55 | scheduler=scheduler, | 69 | scheduler=sample_scheduler, |
56 | ).to(accelerator.device) | 70 | ).to(accelerator.device) |
57 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 71 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
58 | 72 | ||