From f00877a13bce50b02cfc3790f2d18a325e9ff95b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 22:42:44 +0100 Subject: Update --- training/functional.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index c100ea2..c5b514a 100644 --- a/training/functional.py +++ b/training/functional.py @@ -25,17 +25,31 @@ def const(result=None): return fn +def get_models(pretrained_model_name_or_path: str): + tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') + unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') + noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') + sample_scheduler = DPMSolverMultistepScheduler.from_pretrained( + pretrained_model_name_or_path, subfolder='scheduler') + + embeddings = patch_managed_embeddings(text_encoder) + + return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings + + def generate_class_images( - accelerator, - text_encoder, - vae, - unet, - tokenizer, - scheduler, + accelerator: Accelerator, + text_encoder: CLIPTextModel, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + tokenizer: MultiCLIPTokenizer, + sample_scheduler: DPMSolverMultistepScheduler, data_train, - sample_batch_size, - sample_image_size, - sample_steps + sample_batch_size: int, + sample_image_size: int, + sample_steps: int ): missing_data = [item for item in data_train if not item.class_image_path.exists()] @@ -52,7 +66,7 @@ def generate_class_images( vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=scheduler, + scheduler=sample_scheduler, ).to(accelerator.device) pipeline.set_progress_bar_config(dynamic_ncols=True) -- cgit v1.2.3-54-g00ecf