summary | refs | log | tree | commit | diff | stats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-17 21:06:11 +0100
committerVolpeon <git@volpeon.ink>2023-02-17 21:06:11 +0100
commitf894dfecfaa3ec17903b2ac37ac4f071408613db (patch)
tree02bf8439315c832528651186285f8b1fbd649f32 /training
parentInference script: Better scheduler config (diff)
downloadtextual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.tar.gz
textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.tar.bz2
textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.zip
Added Lion optimizer
Diffstat (limited to 'training')
-rw-r--r--  training/functional.py | 9
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 41794ea..4d0cf0e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
12 12
13from accelerate import Accelerator 13from accelerate import Accelerator
14from transformers import CLIPTextModel 14from transformers import CLIPTextModel
15from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, UniPCMultistepScheduler 15from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler
16 16
17from tqdm.auto import tqdm 17from tqdm.auto import tqdm
18from PIL import Image 18from PIL import Image
@@ -22,6 +22,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
22from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings 22from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
23from models.clip.util import get_extended_embeddings 23from models.clip.util import get_extended_embeddings
24from models.clip.tokenizer import MultiCLIPTokenizer 24from models.clip.tokenizer import MultiCLIPTokenizer
25from schedulers.scheduling_deis_multistep import DEISMultistepScheduler
25from training.util import AverageMeter 26from training.util import AverageMeter
26 27
27 28
@@ -78,7 +79,7 @@ def get_models(pretrained_model_name_or_path: str):
78 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 79 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
79 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 80 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
80 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') 81 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
81 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') 82 noise_scheduler = DEISMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
82 sample_scheduler = UniPCMultistepScheduler.from_pretrained( 83 sample_scheduler = UniPCMultistepScheduler.from_pretrained(
83 pretrained_model_name_or_path, subfolder='scheduler') 84 pretrained_model_name_or_path, subfolder='scheduler')
84 85
@@ -251,7 +252,7 @@ def add_placeholder_tokens(
251 252
252def loss_step( 253def loss_step(
253 vae: AutoencoderKL, 254 vae: AutoencoderKL,
254 noise_scheduler: DDPMScheduler, 255 noise_scheduler: DEISMultistepScheduler,
255 unet: UNet2DConditionModel, 256 unet: UNet2DConditionModel,
256 text_encoder: CLIPTextModel, 257 text_encoder: CLIPTextModel,
257 with_prior_preservation: bool, 258 with_prior_preservation: bool,
@@ -551,7 +552,7 @@ def train(
551 unet: UNet2DConditionModel, 552 unet: UNet2DConditionModel,
552 text_encoder: CLIPTextModel, 553 text_encoder: CLIPTextModel,
553 vae: AutoencoderKL, 554 vae: AutoencoderKL,
554 noise_scheduler: DDPMScheduler, 555 noise_scheduler: DEISMultistepScheduler,
555 dtype: torch.dtype, 556 dtype: torch.dtype,
556 seed: int, 557 seed: int,
557 project: str, 558 project: str,