summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-18 13:00:13 +0100
committerVolpeon <git@volpeon.ink>2023-02-18 13:00:13 +0100
commit2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2 (patch)
treee08741c9df3b30a05ade472da45d7410bbf972ae /training/functional.py
parentAdded Lion optimizer (diff)
downloadtextual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.gz
textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.bz2
textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.zip
Update
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py21
1 files changed, 14 insertions, 7 deletions
diff --git a/training/functional.py b/training/functional.py
index 4d0cf0e..85dd884 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
12 12
13from accelerate import Accelerator 13from accelerate import Accelerator
14from transformers import CLIPTextModel 14from transformers import CLIPTextModel
15from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler 15from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin
16 16
17from tqdm.auto import tqdm 17from tqdm.auto import tqdm
18from PIL import Image 18from PIL import Image
@@ -74,12 +74,19 @@ def make_grid(images, rows, cols):
74 return grid 74 return grid
75 75
76 76
77def get_models(pretrained_model_name_or_path: str): 77def get_models(pretrained_model_name_or_path: str, noise_scheduler: str = "ddpm"):
78 if noise_scheduler == "deis":
79 noise_scheduler_cls = DEISMultistepScheduler
80 elif noise_scheduler == "ddpm":
81 noise_scheduler_cls = DDPMScheduler
82 else:
83 raise ValueError(f"noise_scheduler must be one of [\"ddpm\", \"deis\"], got {noise_scheduler}")
84
78 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 85 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
79 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 86 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
80 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 87 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
81 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') 88 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
82 noise_scheduler = DEISMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') 89 noise_scheduler = noise_scheduler_cls.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
83 sample_scheduler = UniPCMultistepScheduler.from_pretrained( 90 sample_scheduler = UniPCMultistepScheduler.from_pretrained(
84 pretrained_model_name_or_path, subfolder='scheduler') 91 pretrained_model_name_or_path, subfolder='scheduler')
85 92
@@ -94,7 +101,7 @@ def save_samples(
94 text_encoder: CLIPTextModel, 101 text_encoder: CLIPTextModel,
95 tokenizer: MultiCLIPTokenizer, 102 tokenizer: MultiCLIPTokenizer,
96 vae: AutoencoderKL, 103 vae: AutoencoderKL,
97 sample_scheduler: UniPCMultistepScheduler, 104 sample_scheduler: SchedulerMixin,
98 train_dataloader: DataLoader, 105 train_dataloader: DataLoader,
99 val_dataloader: Optional[DataLoader], 106 val_dataloader: Optional[DataLoader],
100 output_dir: Path, 107 output_dir: Path,
@@ -181,7 +188,7 @@ def generate_class_images(
181 vae: AutoencoderKL, 188 vae: AutoencoderKL,
182 unet: UNet2DConditionModel, 189 unet: UNet2DConditionModel,
183 tokenizer: MultiCLIPTokenizer, 190 tokenizer: MultiCLIPTokenizer,
184 sample_scheduler: UniPCMultistepScheduler, 191 sample_scheduler: SchedulerMixin,
185 train_dataset: VlpnDataset, 192 train_dataset: VlpnDataset,
186 sample_batch_size: int, 193 sample_batch_size: int,
187 sample_image_size: int, 194 sample_image_size: int,
@@ -252,7 +259,7 @@ def add_placeholder_tokens(
252 259
253def loss_step( 260def loss_step(
254 vae: AutoencoderKL, 261 vae: AutoencoderKL,
255 noise_scheduler: DEISMultistepScheduler, 262 noise_scheduler: SchedulerMixin,
256 unet: UNet2DConditionModel, 263 unet: UNet2DConditionModel,
257 text_encoder: CLIPTextModel, 264 text_encoder: CLIPTextModel,
258 with_prior_preservation: bool, 265 with_prior_preservation: bool,
@@ -552,7 +559,7 @@ def train(
552 unet: UNet2DConditionModel, 559 unet: UNet2DConditionModel,
553 text_encoder: CLIPTextModel, 560 text_encoder: CLIPTextModel,
554 vae: AutoencoderKL, 561 vae: AutoencoderKL,
555 noise_scheduler: DEISMultistepScheduler, 562 noise_scheduler: SchedulerMixin,
556 dtype: torch.dtype, 563 dtype: torch.dtype,
557 seed: int, 564 seed: int,
558 project: str, 565 project: str,