From 2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 18 Feb 2023 13:00:13 +0100 Subject: Update --- training/functional.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index 4d0cf0e..85dd884 100644 --- a/training/functional.py +++ b/training/functional.py @@ -12,7 +12,7 @@ from torch.utils.data import DataLoader from accelerate import Accelerator from transformers import CLIPTextModel -from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler +from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin from tqdm.auto import tqdm from PIL import Image @@ -74,12 +74,19 @@ def make_grid(images, rows, cols): return grid -def get_models(pretrained_model_name_or_path: str): +def get_models(pretrained_model_name_or_path: str, noise_scheduler: str = "ddpm"): + if noise_scheduler == "deis": + noise_scheduler_cls = DEISMultistepScheduler + elif noise_scheduler == "ddpm": + noise_scheduler_cls = DDPMScheduler + else: + raise ValueError(f"noise_scheduler must be one of [\"ddpm\", \"deis\"], got {noise_scheduler}") + tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') - noise_scheduler = DEISMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') + noise_scheduler = noise_scheduler_cls.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') sample_scheduler = UniPCMultistepScheduler.from_pretrained( pretrained_model_name_or_path, subfolder='scheduler') @@ -94,7 +101,7 @@ def save_samples( text_encoder: CLIPTextModel, tokenizer: MultiCLIPTokenizer, vae: AutoencoderKL, - sample_scheduler: UniPCMultistepScheduler, + sample_scheduler: SchedulerMixin, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], output_dir: Path, @@ -181,7 +188,7 @@ def generate_class_images( vae: AutoencoderKL, unet: UNet2DConditionModel, tokenizer: MultiCLIPTokenizer, - sample_scheduler: UniPCMultistepScheduler, + sample_scheduler: SchedulerMixin, train_dataset: VlpnDataset, sample_batch_size: int, sample_image_size: int, @@ -252,7 +259,7 @@ def add_placeholder_tokens( def loss_step( vae: AutoencoderKL, - noise_scheduler: DEISMultistepScheduler, + noise_scheduler: SchedulerMixin, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, with_prior_preservation: bool, @@ -552,7 +559,7 @@ def train( unet: UNet2DConditionModel, text_encoder: CLIPTextModel, vae: AutoencoderKL, - noise_scheduler: DEISMultistepScheduler, + noise_scheduler: SchedulerMixin, dtype: torch.dtype, seed: int, project: str, -- cgit v1.2.3-54-g00ecf