diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/training/functional.py b/training/functional.py index 4d0cf0e..85dd884 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -12,7 +12,7 @@ from torch.utils.data import DataLoader | |||
| 12 | 12 | ||
| 13 | from accelerate import Accelerator | 13 | from accelerate import Accelerator |
| 14 | from transformers import CLIPTextModel | 14 | from transformers import CLIPTextModel |
| 15 | from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler | 15 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin |
| 16 | 16 | ||
| 17 | from tqdm.auto import tqdm | 17 | from tqdm.auto import tqdm |
| 18 | from PIL import Image | 18 | from PIL import Image |
| @@ -74,12 +74,19 @@ def make_grid(images, rows, cols): | |||
| 74 | return grid | 74 | return grid |
| 75 | 75 | ||
| 76 | 76 | ||
| 77 | def get_models(pretrained_model_name_or_path: str): | 77 | def get_models(pretrained_model_name_or_path: str, noise_scheduler: str = "ddpm"): |
| 78 | if noise_scheduler == "deis": | ||
| 79 | noise_scheduler_cls = DEISMultistepScheduler | ||
| 80 | elif noise_scheduler == "ddpm": | ||
| 81 | noise_scheduler_cls = DDPMScheduler | ||
| 82 | else: | ||
| 83 | raise ValueError(f"noise_scheduler must be one of [\"ddpm\", \"deis\"], got {noise_scheduler}") | ||
| 84 | |||
| 78 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 85 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
| 79 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 86 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
| 80 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 87 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
| 81 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') | 88 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') |
| 82 | noise_scheduler = DEISMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | 89 | noise_scheduler = noise_scheduler_cls.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') |
| 83 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 90 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
| 84 | pretrained_model_name_or_path, subfolder='scheduler') | 91 | pretrained_model_name_or_path, subfolder='scheduler') |
| 85 | 92 | ||
| @@ -94,7 +101,7 @@ def save_samples( | |||
| 94 | text_encoder: CLIPTextModel, | 101 | text_encoder: CLIPTextModel, |
| 95 | tokenizer: MultiCLIPTokenizer, | 102 | tokenizer: MultiCLIPTokenizer, |
| 96 | vae: AutoencoderKL, | 103 | vae: AutoencoderKL, |
| 97 | sample_scheduler: UniPCMultistepScheduler, | 104 | sample_scheduler: SchedulerMixin, |
| 98 | train_dataloader: DataLoader, | 105 | train_dataloader: DataLoader, |
| 99 | val_dataloader: Optional[DataLoader], | 106 | val_dataloader: Optional[DataLoader], |
| 100 | output_dir: Path, | 107 | output_dir: Path, |
| @@ -181,7 +188,7 @@ def generate_class_images( | |||
| 181 | vae: AutoencoderKL, | 188 | vae: AutoencoderKL, |
| 182 | unet: UNet2DConditionModel, | 189 | unet: UNet2DConditionModel, |
| 183 | tokenizer: MultiCLIPTokenizer, | 190 | tokenizer: MultiCLIPTokenizer, |
| 184 | sample_scheduler: UniPCMultistepScheduler, | 191 | sample_scheduler: SchedulerMixin, |
| 185 | train_dataset: VlpnDataset, | 192 | train_dataset: VlpnDataset, |
| 186 | sample_batch_size: int, | 193 | sample_batch_size: int, |
| 187 | sample_image_size: int, | 194 | sample_image_size: int, |
| @@ -252,7 +259,7 @@ def add_placeholder_tokens( | |||
| 252 | 259 | ||
| 253 | def loss_step( | 260 | def loss_step( |
| 254 | vae: AutoencoderKL, | 261 | vae: AutoencoderKL, |
| 255 | noise_scheduler: DEISMultistepScheduler, | 262 | noise_scheduler: SchedulerMixin, |
| 256 | unet: UNet2DConditionModel, | 263 | unet: UNet2DConditionModel, |
| 257 | text_encoder: CLIPTextModel, | 264 | text_encoder: CLIPTextModel, |
| 258 | with_prior_preservation: bool, | 265 | with_prior_preservation: bool, |
| @@ -552,7 +559,7 @@ def train( | |||
| 552 | unet: UNet2DConditionModel, | 559 | unet: UNet2DConditionModel, |
| 553 | text_encoder: CLIPTextModel, | 560 | text_encoder: CLIPTextModel, |
| 554 | vae: AutoencoderKL, | 561 | vae: AutoencoderKL, |
| 555 | noise_scheduler: DEISMultistepScheduler, | 562 | noise_scheduler: SchedulerMixin, |
| 556 | dtype: torch.dtype, | 563 | dtype: torch.dtype, |
| 557 | seed: int, | 564 | seed: int, |
| 558 | project: str, | 565 | project: str, |
