 data/csv.py                                         | 47
 models/clip/embeddings.py                           |  4
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 32
 train_dreambooth.py                                 | 71
 train_ti.py                                         | 86
 training/common.py                                  | 55
 6 files changed, 149 insertions(+), 146 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 9ad7dd6..f5fc8e6 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,7 +1,7 @@
 import math
 import torch
 import json
-import copy
+from functools import partial
 from pathlib import Path
 from typing import NamedTuple, Optional, Union, Callable
 
@@ -99,6 +99,41 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
+def collate_fn(
+    num_class_images: int,
+    weight_dtype: torch.dtype,
+    prompt_processor: PromptProcessor,
+    examples
+):
+    prompt_ids = [example["prompt_ids"] for example in examples]
+    nprompt_ids = [example["nprompt_ids"] for example in examples]
+
+    input_ids = [example["instance_prompt_ids"] for example in examples]
+    pixel_values = [example["instance_images"] for example in examples]
+
+    # concat class and instance examples for prior preservation
+    if num_class_images != 0 and "class_prompt_ids" in examples[0]:
+        input_ids += [example["class_prompt_ids"] for example in examples]
+        pixel_values += [example["class_images"] for example in examples]
+
+    pixel_values = torch.stack(pixel_values)
+    pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
+
+    prompts = prompt_processor.unify_input_ids(prompt_ids)
+    nprompts = prompt_processor.unify_input_ids(nprompt_ids)
+    inputs = prompt_processor.unify_input_ids(input_ids)
+
+    batch = {
+        "prompt_ids": prompts.input_ids,
+        "nprompt_ids": nprompts.input_ids,
+        "input_ids": inputs.input_ids,
+        "pixel_values": pixel_values,
+        "attention_mask": inputs.attention_mask,
+    }
+
+    return batch
+
+
 class VlpnDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -129,7 +164,7 @@ class VlpnDataModule():
         valid_set_repeat: int = 1,
         seed: Optional[int] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
-        collate_fn=None,
+        dtype: torch.dtype = torch.float32,
         num_workers: int = 0
     ):
         super().__init__()
@@ -158,9 +193,9 @@ class VlpnDataModule():
         self.valid_set_repeat = valid_set_repeat
         self.seed = seed
         self.filter = filter
-        self.collate_fn = collate_fn
         self.num_workers = num_workers
         self.batch_size = batch_size
+        self.dtype = dtype
 
     def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
         image = template["image"] if "image" in template else "{}"
@@ -254,14 +289,16 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
+        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor)
+
         self.train_dataloader = DataLoader(
             train_dataset,
-            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
        )
 
         self.val_dataloader = DataLoader(
             val_dataset,
-            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
        )
 
 
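
The collate_fn moved into data/csv.py is no longer a closure over script-local state, so VlpnDataModule binds the configuration up front with functools.partial and hands DataLoader a one-argument callable. A minimal, self-contained sketch of that pattern (the toy collate below is illustrative, not the real one):

from functools import partial

def collate(num_class_images, weight_dtype, examples):
    # Toy stand-in for data.csv.collate_fn: the leading parameters are
    # configuration; the trailing one is the batch list DataLoader supplies.
    return {"n": num_class_images, "dtype": weight_dtype, "batch": list(examples)}

collate_ = partial(collate, 0, "float32")  # pre-bind the configuration
print(collate_([1, 2, 3]))                 # called DataLoader-style with only the examples
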
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 46b414b..9a23a2a 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -99,12 +99,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return embeds
 
-    def normalize(self, lambda_: float = 1.0):
+    def normalize(self, target: float = 0.4, lambda_: float = 1.0):
         w = self.temp_token_embedding.weight
         pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
         w[self.temp_token_ids] = F.normalize(
             w[self.temp_token_ids, :], dim=-1
-        ) * (pre_norm + lambda_ * (0.4 - pre_norm))
+        ) * (pre_norm + lambda_ * (target - pre_norm))
 
     def forward(
         self,
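
With the new signature, normalize() moves each temporary token embedding's L2 norm a fraction lambda_ of the way from its current value toward target (lambda_=1.0 snaps it exactly to target; 0.4 is the new default). A small sketch of the arithmetic on random data, assuming only torch:

import torch
import torch.nn.functional as F

w = torch.randn(4, 768)                      # stand-in for temporary token embeddings
pre_norm = w.norm(dim=-1, keepdim=True)
target, lambda_ = 0.4, 0.25
w = F.normalize(w, dim=-1) * (pre_norm + lambda_ * (target - pre_norm))
print(w.norm(dim=-1))                        # each norm moved 25% of the way toward 0.4
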
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cb300d1..6bc40e9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,7 +20,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import logging
+from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
 
@@ -250,8 +250,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return timesteps
 
-    def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -260,28 +260,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
 
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
-            latents = latents.to(device)
+            latents = latents.to(device=device, dtype=dtype)
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
 
         return latents
 
-    def prepare_latents_from_image(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+    def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
         init_latent_dist = self.vae.encode(init_image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
@@ -292,7 +280,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
             )
         else:
-            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+            init_latents = torch.cat([init_latents] * batch_size, dim=0)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
@@ -430,16 +418,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
             latents = self.prepare_latents_from_image(
                 image,
                 latent_timestep,
-                batch_size,
-                num_images_per_prompt,
+                batch_size * num_images_per_prompt,
                 text_embeddings.dtype,
                 device,
                 generator
             )
         else:
             latents = self.prepare_latents(
-                batch_size,
-                num_images_per_prompt,
+                batch_size * num_images_per_prompt,
                 num_channels_latents,
                 height,
                 width,
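
The pipeline now folds num_images_per_prompt into the batch size before calling prepare_latents and derives the latent resolution from self.vae_scale_factor rather than a hard-coded 8 (8 is the factor for the standard Stable Diffusion VAE). The shape bookkeeping, with illustrative numbers only:

batch_size, num_images_per_prompt = 2, 3
num_channels_latents, height, width, vae_scale_factor = 4, 512, 512, 8

shape = (batch_size * num_images_per_prompt, num_channels_latents,
         height // vae_scale_factor, width // vae_scale_factor)
print(shape)  # (6, 4, 64, 64)
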
diff --git a/train_dreambooth.py b/train_dreambooth.py
index ebcf802..da3a075 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -14,7 +14,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
@@ -24,8 +23,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images
-from training.optimization import get_one_cycle_schedule
+from training.common import loss_step, generate_class_images, get_scheduler
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.embeddings import patch_managed_embeddings
@@ -750,35 +748,6 @@ def main():
         )
         return cond3 and cond4
 
-    def collate_fn(examples):
-        prompt_ids = [example["prompt_ids"] for example in examples]
-        nprompt_ids = [example["nprompt_ids"] for example in examples]
-
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # concat class and instance examples for prior preservation
-        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
-
-        prompts = prompt_processor.unify_input_ids(prompt_ids)
-        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-        inputs = prompt_processor.unify_input_ids(input_ids)
-
-        batch = {
-            "prompt_ids": prompts.input_ids,
-            "nprompt_ids": nprompts.input_ids,
-            "input_ids": inputs.input_ids,
-            "pixel_values": pixel_values,
-            "attention_mask": inputs.attention_mask,
-        }
-
-        return batch
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -798,7 +767,7 @@ def main():
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
-        collate_fn=collate_fn
+        dtype=weight_dtype
     )
 
     datamodule.prepare_data()
@@ -829,33 +798,23 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
-
-    if args.lr_scheduler == "one_cycle":
-        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            warmup=args.lr_warmup_func,
-            annealing=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            min_lr=lr_min_lr,
-        )
-    elif args.lr_scheduler == "cosine_with_restarts":
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
-        )
+    if args.find_lr:
+        lr_scheduler = None
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            lr=args.learning_rate,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            warmup_epochs=args.lr_warmup_epochs,
+            max_train_steps=args.max_train_steps,
+            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            gradient_accumulation_steps=args.gradient_accumulation_steps
        )
 
    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
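
The warmup and cycle arithmetic that used to live in this script now happens inside training.common.get_scheduler; the script only forwards the raw CLI values. What the helper computes internally, with toy numbers:

import math

warmup_epochs, num_update_steps_per_epoch, gradient_accumulation_steps = 2, 50, 4
max_train_steps = 600

warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
default_cycles = math.ceil(math.sqrt((max_train_steps - warmup_steps) / num_update_steps_per_epoch))
print(warmup_steps, default_cycles)  # 400 2 -- the cycle default applies when --lr_cycles is unset
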
diff --git a/train_ti.py b/train_ti.py
index 9ec5cfb..3b7e3b1 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,7 +13,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel
@@ -22,8 +21,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images
-from training.optimization import get_one_cycle_schedule
+from training.common import loss_step, generate_class_images, get_scheduler
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
 from models.clip.embeddings import patch_managed_embeddings
@@ -410,10 +408,16 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--max_grad_norm",
-        default=3.0,
+        "--decay_target",
+        default=0.4,
         type=float,
-        help="Max gradient norm."
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--decay_factor",
+        default=100,
+        type=float,
+        help="Embedding decay factor."
     )
     parser.add_argument(
         "--noise_timesteps",
@@ -709,35 +713,6 @@ def main():
         )
         return cond1 and cond3 and cond4
 
-    def collate_fn(examples):
-        prompt_ids = [example["prompt_ids"] for example in examples]
-        nprompt_ids = [example["nprompt_ids"] for example in examples]
-
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # concat class and instance examples for prior preservation
-        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
-
-        prompts = prompt_processor.unify_input_ids(prompt_ids)
-        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-        inputs = prompt_processor.unify_input_ids(input_ids)
-
-        batch = {
-            "prompt_ids": prompts.input_ids,
-            "nprompt_ids": nprompts.input_ids,
-            "input_ids": inputs.input_ids,
-            "pixel_values": pixel_values,
-            "attention_mask": inputs.attention_mask,
-        }
-
-        return batch
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -757,7 +732,7 @@
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
-        collate_fn=collate_fn
+        dtype=weight_dtype
    )
    datamodule.setup()
 
@@ -786,35 +761,23 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
-
     if args.find_lr:
         lr_scheduler = None
-    elif args.lr_scheduler == "one_cycle":
-        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            warmup=args.lr_warmup_func,
-            annealing=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            min_lr=lr_min_lr,
-        )
-    elif args.lr_scheduler == "cosine_with_restarts":
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
-        )
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            lr=args.learning_rate,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            warmup_epochs=args.lr_warmup_epochs,
+            max_train_steps=args.max_train_steps,
+            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            gradient_accumulation_steps=args.gradient_accumulation_steps
        )
 
    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
@@ -868,7 +831,10 @@ def main():
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+        text_encoder.text_model.embeddings.normalize(
+            args.decay_target,
+            min(1.0, args.decay_factor * lr)
+        )
 
     loop = partial(
         loss_step,
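
on_after_optimize now scales the embedding decay by the current learning rate, capped at 1.0, with both the target and the factor exposed as CLI flags instead of the hard-coded min(1.0, 100 * lr). With the new defaults and an illustrative learning rate:

decay_target, decay_factor, lr = 0.4, 100.0, 1e-4   # --decay_target / --decay_factor defaults
lambda_ = min(1.0, decay_factor * lr)
print(lambda_)  # 0.01 -> embedding norms move 1% of the way toward 0.4 after this step
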
diff --git a/training/common.py b/training/common.py
index 0b2ae44..90cf910 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,10 +1,65 @@
+import math
+
 import torch
 import torch.nn.functional as F
 
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 
+from training.optimization import get_one_cycle_schedule
+
+
+def get_scheduler(
+    id: str,
+    min_lr: float,
+    lr: float,
+    warmup_func: str,
+    annealing_func: str,
+    warmup_exp: int,
+    annealing_exp: int,
+    cycles: int,
+    warmup_epochs: int,
+    optimizer: torch.optim.Optimizer,
+    max_train_steps: int,
+    num_update_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+):
+    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+
+    if id == "one_cycle":
+        min_lr = 0.04 if min_lr is None else min_lr / lr
+
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        cycles = cycles if cycles is not None else math.ceil(
+            math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+        )
+
+    return lr_scheduler
+
 
 def generate_class_images(
     accelerator,
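
Both training scripts now call this shared helper. A minimal usage sketch with a toy model and illustrative argument values (only the "cosine_with_restarts" branch is exercised here; the warmup/annealing shape arguments matter only for "one_cycle"):

import torch

from training.common import get_scheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_scheduler = get_scheduler(
    "cosine_with_restarts",
    optimizer=optimizer,
    min_lr=None,
    lr=1e-4,
    warmup_func="linear",
    annealing_func="cosine",
    warmup_exp=1,
    annealing_exp=1,
    cycles=None,                        # falls back to the sqrt-based default
    warmup_epochs=2,
    max_train_steps=600,
    num_update_steps_per_epoch=50,
    gradient_accumulation_steps=1,
)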
