diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 9 |
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/training/functional.py b/training/functional.py index 41794ea..4d0cf0e 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader | |||
12 | 12 | ||
13 | from accelerate import Accelerator | 13 | from accelerate import Accelerator |
14 | from transformers import CLIPTextModel | 14 | from transformers import CLIPTextModel |
15 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, UniPCMultistepScheduler | 15 | from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler |
16 | 16 | ||
17 | from tqdm.auto import tqdm | 17 | from tqdm.auto import tqdm |
18 | from PIL import Image | 18 | from PIL import Image |
@@ -22,6 +22,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
22 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | 22 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings |
23 | from models.clip.util import get_extended_embeddings | 23 | from models.clip.util import get_extended_embeddings |
24 | from models.clip.tokenizer import MultiCLIPTokenizer | 24 | from models.clip.tokenizer import MultiCLIPTokenizer |
25 | from schedulers.scheduling_deis_multistep import DEISMultistepScheduler | ||
25 | from training.util import AverageMeter | 26 | from training.util import AverageMeter |
26 | 27 | ||
27 | 28 | ||
@@ -78,7 +79,7 @@ def get_models(pretrained_model_name_or_path: str): | |||
78 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 79 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
79 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 80 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
80 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') | 81 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') |
81 | noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | 82 | noise_scheduler = DEISMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') |
82 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 83 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
83 | pretrained_model_name_or_path, subfolder='scheduler') | 84 | pretrained_model_name_or_path, subfolder='scheduler') |
84 | 85 | ||
@@ -251,7 +252,7 @@ def add_placeholder_tokens( | |||
251 | 252 | ||
252 | def loss_step( | 253 | def loss_step( |
253 | vae: AutoencoderKL, | 254 | vae: AutoencoderKL, |
254 | noise_scheduler: DDPMScheduler, | 255 | noise_scheduler: DEISMultistepScheduler, |
255 | unet: UNet2DConditionModel, | 256 | unet: UNet2DConditionModel, |
256 | text_encoder: CLIPTextModel, | 257 | text_encoder: CLIPTextModel, |
257 | with_prior_preservation: bool, | 258 | with_prior_preservation: bool, |
@@ -551,7 +552,7 @@ def train( | |||
551 | unet: UNet2DConditionModel, | 552 | unet: UNet2DConditionModel, |
552 | text_encoder: CLIPTextModel, | 553 | text_encoder: CLIPTextModel, |
553 | vae: AutoencoderKL, | 554 | vae: AutoencoderKL, |
554 | noise_scheduler: DDPMScheduler, | 555 | noise_scheduler: DEISMultistepScheduler, |
555 | dtype: torch.dtype, | 556 | dtype: torch.dtype, |
556 | seed: int, | 557 | seed: int, |
557 | project: str, | 558 | project: str, |