-rw-r--r--  data/csv.py                                          | 47
-rw-r--r--  models/clip/embeddings.py                            |  4
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 32
-rw-r--r--  train_dreambooth.py                                  | 71
-rw-r--r--  train_ti.py                                          | 86
-rw-r--r--  training/common.py                                   | 55

6 files changed, 149 insertions(+), 146 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 9ad7dd6..f5fc8e6 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,7 +1,7 @@
 import math
 import torch
 import json
-import copy
+from functools import partial
 from pathlib import Path
 from typing import NamedTuple, Optional, Union, Callable
 
@@ -99,6 +99,41 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
+def collate_fn(
+    num_class_images: int,
+    weight_dtype: torch.dtype,
+    prompt_processor: PromptProcessor,
+    examples
+):
+    prompt_ids = [example["prompt_ids"] for example in examples]
+    nprompt_ids = [example["nprompt_ids"] for example in examples]
+
+    input_ids = [example["instance_prompt_ids"] for example in examples]
+    pixel_values = [example["instance_images"] for example in examples]
+
+    # concat class and instance examples for prior preservation
+    if num_class_images != 0 and "class_prompt_ids" in examples[0]:
+        input_ids += [example["class_prompt_ids"] for example in examples]
+        pixel_values += [example["class_images"] for example in examples]
+
+    pixel_values = torch.stack(pixel_values)
+    pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
+
+    prompts = prompt_processor.unify_input_ids(prompt_ids)
+    nprompts = prompt_processor.unify_input_ids(nprompt_ids)
+    inputs = prompt_processor.unify_input_ids(input_ids)
+
+    batch = {
+        "prompt_ids": prompts.input_ids,
+        "nprompt_ids": nprompts.input_ids,
+        "input_ids": inputs.input_ids,
+        "pixel_values": pixel_values,
+        "attention_mask": inputs.attention_mask,
+    }
+
+    return batch
+
+
 class VlpnDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -129,7 +164,7 @@ class VlpnDataModule():
         valid_set_repeat: int = 1,
         seed: Optional[int] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
-        collate_fn=None,
+        dtype: torch.dtype = torch.float32,
         num_workers: int = 0
     ):
         super().__init__()
@@ -158,9 +193,9 @@ class VlpnDataModule():
         self.valid_set_repeat = valid_set_repeat
         self.seed = seed
         self.filter = filter
-        self.collate_fn = collate_fn
         self.num_workers = num_workers
         self.batch_size = batch_size
+        self.dtype = dtype
 
     def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
         image = template["image"] if "image" in template else "{}"
@@ -254,14 +289,16 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
+        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor)
+
        self.train_dataloader = DataLoader(
            train_dataset,
-            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
        )
 
        self.val_dataloader = DataLoader(
            val_dataset,
-            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
        )
 
 
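Note: collate_fn now lives at module level and takes its configuration (num_class_images, weight_dtype, prompt_processor) as leading positional parameters, so VlpnDataModule.setup() can pre-bind them with functools.partial. A minimal sketch of the call pattern, with hypothetical values that are not from this commit:

    from functools import partial
    import torch

    # pre-bind the configuration; the DataLoader later calls collate_fn_(examples),
    # which resolves to collate_fn(0, torch.float16, prompt_processor, examples)
    collate_fn_ = partial(collate_fn, 0, torch.float16, prompt_processor)

One practical side effect: a partial over a module-level function is picklable, while the per-script closures this replaces (which captured args and weight_dtype) are not, which can matter when num_workers > 0 spawns DataLoader worker processes.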
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 46b414b..9a23a2a 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -99,12 +99,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return embeds
 
-    def normalize(self, lambda_: float = 1.0):
+    def normalize(self, target: float = 0.4, lambda_: float = 1.0):
         w = self.temp_token_embedding.weight
         pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
         w[self.temp_token_ids] = F.normalize(
             w[self.temp_token_ids, :], dim=-1
-        ) * (pre_norm + lambda_ * (0.4 - pre_norm))
+        ) * (pre_norm + lambda_ * (target - pre_norm))
 
     def forward(
         self,
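Note: with the new target parameter, normalize rescales each temporary token embedding so its norm moves a fraction lambda_ of the way toward target, leaving the direction unchanged (F.normalize reduces the vectors to unit length before rescaling). Repeated calls decay the norms geometrically: after k steps, norm_k = target + (1 - lambda_)**k * (norm_0 - target). A standalone sketch with made-up tensors:

    import torch
    import torch.nn.functional as F

    w = torch.randn(4, 768) * 3.0    # stand-in embeddings with oversized norms
    target, lambda_ = 0.4, 0.01

    pre_norm = w.norm(dim=-1, keepdim=True)
    w = F.normalize(w, dim=-1) * (pre_norm + lambda_ * (target - pre_norm))
    # each row's norm is now (1 - lambda_) * pre_norm + lambda_ * target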
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cb300d1..6bc40e9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,7 +20,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import logging
+from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
 
@@ -250,8 +250,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return timesteps
 
-    def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -260,28 +260,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
 
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
-            latents = latents.to(device)
+            latents = latents.to(device=device, dtype=dtype)
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
 
         return latents
 
-    def prepare_latents_from_image(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+    def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
         init_latent_dist = self.vae.encode(init_image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
@@ -292,7 +280,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
             )
         else:
-            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+            init_latents = torch.cat([init_latents] * batch_size, dim=0)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
@@ -430,16 +418,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
             latents = self.prepare_latents_from_image(
                 image,
                 latent_timestep,
-                batch_size,
-                num_images_per_prompt,
+                batch_size * num_images_per_prompt,
                 text_embeddings.dtype,
                 device,
                 generator
             )
         else:
             latents = self.prepare_latents(
-                batch_size,
-                num_images_per_prompt,
+                batch_size * num_images_per_prompt,
                 num_channels_latents,
                 height,
                 width,
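Note: randn_tensor is the diffusers helper that subsumes the hand-rolled sampling deleted above. Roughly, it behaves like the following sketch (an approximation for illustration, not the library source):

    import torch

    def randn_tensor_sketch(shape, generator=None, device=None, dtype=None):
        # MPS has historically lacked reliable randn support, so sample on CPU and move
        rand_device = "cpu" if device.type == "mps" else device
        if isinstance(generator, list):
            # one generator per batch item gives per-sample reproducibility
            per_sample = (1,) + tuple(shape[1:])
            chunks = [
                torch.randn(per_sample, generator=g, device=rand_device, dtype=dtype)
                for g in generator
            ]
            return torch.cat(chunks, dim=0).to(device)
        return torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)

The explicit shape check on caller-supplied latents is gone as well; the tensor is now simply moved to the requested device and dtype.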
diff --git a/train_dreambooth.py b/train_dreambooth.py
index ebcf802..da3a075 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -14,7 +14,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
@@ -24,8 +23,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images
-from training.optimization import get_one_cycle_schedule
+from training.common import loss_step, generate_class_images, get_scheduler
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.embeddings import patch_managed_embeddings
@@ -750,35 +748,6 @@ def main():
         )
         return cond3 and cond4
 
-    def collate_fn(examples):
-        prompt_ids = [example["prompt_ids"] for example in examples]
-        nprompt_ids = [example["nprompt_ids"] for example in examples]
-
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # concat class and instance examples for prior preservation
-        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
-
-        prompts = prompt_processor.unify_input_ids(prompt_ids)
-        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-        inputs = prompt_processor.unify_input_ids(input_ids)
-
-        batch = {
-            "prompt_ids": prompts.input_ids,
-            "nprompt_ids": nprompts.input_ids,
-            "input_ids": inputs.input_ids,
-            "pixel_values": pixel_values,
-            "attention_mask": inputs.attention_mask,
-        }
-
-        return batch
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -798,7 +767,7 @@ def main():
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
-        collate_fn=collate_fn
+        dtype=weight_dtype
     )
 
     datamodule.prepare_data()
@@ -829,33 +798,23 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
-
-    if args.lr_scheduler == "one_cycle":
-        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            warmup=args.lr_warmup_func,
-            annealing=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            min_lr=lr_min_lr,
-        )
-    elif args.lr_scheduler == "cosine_with_restarts":
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
-        )
+    if args.find_lr:
+        lr_scheduler = None
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            lr=args.learning_rate,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            warmup_epochs=args.lr_warmup_epochs,
+            max_train_steps=args.max_train_steps,
+            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            gradient_accumulation_steps=args.gradient_accumulation_steps
         )
 
     unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
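Note: the scheduler branching above now lives behind a single get_scheduler call (see training/common.py below), which derives the warmup step count and the default number of restart cycles internally. Illustrative arithmetic with assumed numbers:

    import math

    # hypothetical run configuration, not from this commit
    max_train_steps, num_update_steps_per_epoch = 3000, 30
    gradient_accumulation_steps, lr_warmup_epochs = 1, 10

    warmup_steps = lr_warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
    cycles = math.ceil(math.sqrt((max_train_steps - warmup_steps) / num_update_steps_per_epoch))
    print(warmup_steps, cycles)    # 300 10, since ceil(sqrt(90)) == 10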
diff --git a/train_ti.py b/train_ti.py
index 9ec5cfb..3b7e3b1 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,7 +13,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel
@@ -22,8 +21,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images
-from training.optimization import get_one_cycle_schedule
+from training.common import loss_step, generate_class_images, get_scheduler
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
 from models.clip.embeddings import patch_managed_embeddings
@@ -410,10 +408,16 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--max_grad_norm",
-        default=3.0,
+        "--decay_target",
+        default=0.4,
         type=float,
-        help="Max gradient norm."
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--decay_factor",
+        default=100,
+        type=float,
+        help="Embedding decay factor."
     )
     parser.add_argument(
         "--noise_timesteps",
@@ -709,35 +713,6 @@ def main():
         )
         return cond1 and cond3 and cond4
 
-    def collate_fn(examples):
-        prompt_ids = [example["prompt_ids"] for example in examples]
-        nprompt_ids = [example["nprompt_ids"] for example in examples]
-
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # concat class and instance examples for prior preservation
-        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
-
-        prompts = prompt_processor.unify_input_ids(prompt_ids)
-        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-        inputs = prompt_processor.unify_input_ids(input_ids)
-
-        batch = {
-            "prompt_ids": prompts.input_ids,
-            "nprompt_ids": nprompts.input_ids,
-            "input_ids": inputs.input_ids,
-            "pixel_values": pixel_values,
-            "attention_mask": inputs.attention_mask,
-        }
-
-        return batch
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -757,7 +732,7 @@ def main():
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
-        collate_fn=collate_fn
+        dtype=weight_dtype
     )
     datamodule.setup()
 
@@ -786,35 +761,23 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
-
     if args.find_lr:
         lr_scheduler = None
-    elif args.lr_scheduler == "one_cycle":
-        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            warmup=args.lr_warmup_func,
-            annealing=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            min_lr=lr_min_lr,
-        )
-    elif args.lr_scheduler == "cosine_with_restarts":
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
-        )
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            lr=args.learning_rate,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            warmup_epochs=args.lr_warmup_epochs,
+            max_train_steps=args.max_train_steps,
+            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            gradient_accumulation_steps=args.gradient_accumulation_steps
        )
 
     text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
@@ -868,7 +831,10 @@ def main():
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+        text_encoder.text_model.embeddings.normalize(
+            args.decay_target,
+            min(1.0, args.decay_factor * lr)
+        )
 
     loop = partial(
         loss_step,
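Note: on_after_optimize scales the decay strength by the current learning rate and clamps it at 1.0, so the norm decay weakens as the learning rate anneals. With the new defaults (decay_target=0.4, decay_factor=100):

    for lr in (1e-4, 1e-3, 1e-2):
        lambda_ = min(1.0, 100 * lr)
        print(lr, lambda_)    # 0.0001 -> 0.01, 0.001 -> 0.1, 0.01 -> 1.0 (snaps norms to 0.4)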
diff --git a/training/common.py b/training/common.py
index 0b2ae44..90cf910 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,10 +1,65 @@
+import math
+
 import torch
 import torch.nn.functional as F
 
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 
+from training.optimization import get_one_cycle_schedule
+
+
+def get_scheduler(
+    id: str,
+    min_lr: float,
+    lr: float,
+    warmup_func: str,
+    annealing_func: str,
+    warmup_exp: int,
+    annealing_exp: int,
+    cycles: int,
+    warmup_epochs: int,
+    optimizer: torch.optim.Optimizer,
+    max_train_steps: int,
+    num_update_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+):
+    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+
+    if id == "one_cycle":
+        min_lr = 0.04 if min_lr is None else min_lr / lr
+
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        cycles = cycles if cycles is not None else math.ceil(
+            math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+        )
+
+    return lr_scheduler
+
 
 def generate_class_images(
     accelerator,
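Note: for reference, a hypothetical invocation of the consolidated helper, mirroring how the training scripts now call it (optimizer is assumed to exist; the argument values are examples only):

    lr_scheduler = get_scheduler(
        "one_cycle",
        optimizer=optimizer,
        min_lr=None,                  # None falls back to a relative floor of 0.04 for one_cycle
        lr=1e-4,
        warmup_func="cos",            # assumed value, passed through to get_one_cycle_schedule
        annealing_func="cos",
        warmup_exp=1,
        annealing_exp=1,
        cycles=None,                  # None derives ceil(sqrt(...)) for cosine_with_restarts
        warmup_epochs=10,
        max_train_steps=3000,
        num_update_steps_per_epoch=30,
        gradient_accumulation_steps=1,
    )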