diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/training/functional.py b/training/functional.py index 4d0cf0e..85dd884 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader | |||
12 | 12 | ||
13 | from accelerate import Accelerator | 13 | from accelerate import Accelerator |
14 | from transformers import CLIPTextModel | 14 | from transformers import CLIPTextModel |
15 | from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler | 15 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin |
16 | 16 | ||
17 | from tqdm.auto import tqdm | 17 | from tqdm.auto import tqdm |
18 | from PIL import Image | 18 | from PIL import Image |
@@ -74,12 +74,19 @@ def make_grid(images, rows, cols): | |||
74 | return grid | 74 | return grid |
75 | 75 | ||
76 | 76 | ||
77 | def get_models(pretrained_model_name_or_path: str): | 77 | def get_models(pretrained_model_name_or_path: str, noise_scheduler: str = "ddpm"): |
78 | if noise_scheduler == "deis": | ||
79 | noise_scheduler_cls = DEISMultistepScheduler | ||
80 | elif noise_scheduler == "ddpm": | ||
81 | noise_scheduler_cls = DDPMScheduler | ||
82 | else: | ||
83 | raise ValueError(f"noise_scheduler must be one of [\"ddpm\", \"deis\"], got {noise_scheduler}") | ||
84 | |||
78 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 85 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
79 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 86 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
80 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 87 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
81 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') | 88 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') |
82 | noise_scheduler = DEISMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | 89 | noise_scheduler = noise_scheduler_cls.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') |
83 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 90 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
84 | pretrained_model_name_or_path, subfolder='scheduler') | 91 | pretrained_model_name_or_path, subfolder='scheduler') |
85 | 92 | ||
@@ -94,7 +101,7 @@ def save_samples( | |||
94 | text_encoder: CLIPTextModel, | 101 | text_encoder: CLIPTextModel, |
95 | tokenizer: MultiCLIPTokenizer, | 102 | tokenizer: MultiCLIPTokenizer, |
96 | vae: AutoencoderKL, | 103 | vae: AutoencoderKL, |
97 | sample_scheduler: UniPCMultistepScheduler, | 104 | sample_scheduler: SchedulerMixin, |
98 | train_dataloader: DataLoader, | 105 | train_dataloader: DataLoader, |
99 | val_dataloader: Optional[DataLoader], | 106 | val_dataloader: Optional[DataLoader], |
100 | output_dir: Path, | 107 | output_dir: Path, |
@@ -181,7 +188,7 @@ def generate_class_images( | |||
181 | vae: AutoencoderKL, | 188 | vae: AutoencoderKL, |
182 | unet: UNet2DConditionModel, | 189 | unet: UNet2DConditionModel, |
183 | tokenizer: MultiCLIPTokenizer, | 190 | tokenizer: MultiCLIPTokenizer, |
184 | sample_scheduler: UniPCMultistepScheduler, | 191 | sample_scheduler: SchedulerMixin, |
185 | train_dataset: VlpnDataset, | 192 | train_dataset: VlpnDataset, |
186 | sample_batch_size: int, | 193 | sample_batch_size: int, |
187 | sample_image_size: int, | 194 | sample_image_size: int, |
@@ -252,7 +259,7 @@ def add_placeholder_tokens( | |||
252 | 259 | ||
253 | def loss_step( | 260 | def loss_step( |
254 | vae: AutoencoderKL, | 261 | vae: AutoencoderKL, |
255 | noise_scheduler: DEISMultistepScheduler, | 262 | noise_scheduler: SchedulerMixin, |
256 | unet: UNet2DConditionModel, | 263 | unet: UNet2DConditionModel, |
257 | text_encoder: CLIPTextModel, | 264 | text_encoder: CLIPTextModel, |
258 | with_prior_preservation: bool, | 265 | with_prior_preservation: bool, |
@@ -552,7 +559,7 @@ def train( | |||
552 | unet: UNet2DConditionModel, | 559 | unet: UNet2DConditionModel, |
553 | text_encoder: CLIPTextModel, | 560 | text_encoder: CLIPTextModel, |
554 | vae: AutoencoderKL, | 561 | vae: AutoencoderKL, |
555 | noise_scheduler: DEISMultistepScheduler, | 562 | noise_scheduler: SchedulerMixin, |
556 | dtype: torch.dtype, | 563 | dtype: torch.dtype, |
557 | seed: int, | 564 | seed: int, |
558 | project: str, | 565 | project: str, |