From f894dfecfaa3ec17903b2ac37ac4f071408613db Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 17 Feb 2023 21:06:11 +0100 Subject: Added Lion optimizer --- training/functional.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 41794ea..4d0cf0e 100644 --- a/training/functional.py +++ b/training/functional.py @@ -12,7 +12,7 @@ from torch.utils.data import DataLoader from accelerate import Accelerator from transformers import CLIPTextModel -from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, UniPCMultistepScheduler +from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler from tqdm.auto import tqdm from PIL import Image @@ -22,6 +22,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings from models.clip.util import get_extended_embeddings from models.clip.tokenizer import MultiCLIPTokenizer +from schedulers.scheduling_deis_multistep import DEISMultistepScheduler from training.util import AverageMeter @@ -78,7 +79,7 @@ def get_models(pretrained_model_name_or_path: str): text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') - noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') + noise_scheduler = DEISMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') sample_scheduler = UniPCMultistepScheduler.from_pretrained( pretrained_model_name_or_path, subfolder='scheduler') @@ -251,7 +252,7 @@ def add_placeholder_tokens( def loss_step( vae: AutoencoderKL, - noise_scheduler: DDPMScheduler, + noise_scheduler: DEISMultistepScheduler, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, with_prior_preservation: bool, @@ -551,7 +552,7 @@ def train( unet: UNet2DConditionModel, text_encoder: CLIPTextModel, vae: AutoencoderKL, - noise_scheduler: DDPMScheduler, + noise_scheduler: DEISMultistepScheduler, dtype: torch.dtype, seed: int, project: str, -- cgit v1.2.3-54-g00ecf