 data/csv.py                                          | 47 +-
 models/clip/embeddings.py                            |  4 +-
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 32 +-
 train_dreambooth.py                                  | 71 +-
 train_ti.py                                          | 86 +-
 training/common.py                                   | 55 +-
 6 files changed, 149 insertions(+), 146 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 9ad7dd6..f5fc8e6 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,7 +1,7 @@
 import math
 import torch
 import json
-import copy
+from functools import partial
 from pathlib import Path
 from typing import NamedTuple, Optional, Union, Callable
 
@@ -99,6 +99,41 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
+def collate_fn(
+    num_class_images: int,
+    weight_dtype: torch.dtype,
+    prompt_processor: PromptProcessor,
+    examples
+):
+    prompt_ids = [example["prompt_ids"] for example in examples]
+    nprompt_ids = [example["nprompt_ids"] for example in examples]
+
+    input_ids = [example["instance_prompt_ids"] for example in examples]
+    pixel_values = [example["instance_images"] for example in examples]
+
+    # concat class and instance examples for prior preservation
+    if num_class_images != 0 and "class_prompt_ids" in examples[0]:
+        input_ids += [example["class_prompt_ids"] for example in examples]
+        pixel_values += [example["class_images"] for example in examples]
+
+    pixel_values = torch.stack(pixel_values)
+    pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
+
+    prompts = prompt_processor.unify_input_ids(prompt_ids)
+    nprompts = prompt_processor.unify_input_ids(nprompt_ids)
+    inputs = prompt_processor.unify_input_ids(input_ids)
+
+    batch = {
+        "prompt_ids": prompts.input_ids,
+        "nprompt_ids": nprompts.input_ids,
+        "input_ids": inputs.input_ids,
+        "pixel_values": pixel_values,
+        "attention_mask": inputs.attention_mask,
+    }
+
+    return batch
+
+
 class VlpnDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -129,7 +164,7 @@ class VlpnDataModule():
         valid_set_repeat: int = 1,
         seed: Optional[int] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
-        collate_fn=None,
+        dtype: torch.dtype = torch.float32,
         num_workers: int = 0
     ):
         super().__init__()
@@ -158,9 +193,9 @@ class VlpnDataModule():
         self.valid_set_repeat = valid_set_repeat
         self.seed = seed
         self.filter = filter
-        self.collate_fn = collate_fn
         self.num_workers = num_workers
         self.batch_size = batch_size
+        self.dtype = dtype
 
     def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
         image = template["image"] if "image" in template else "{}"
@@ -254,14 +289,16 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
+        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor)
+
         self.train_dataloader = DataLoader(
             train_dataset,
-            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
         )
 
         self.val_dataloader = DataLoader(
             val_dataset,
-            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
         )
 
 
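
Note: collate_fn now lives at module level in data/csv.py and takes its configuration
(num_class_images, weight_dtype, prompt_processor) as leading arguments, so VlpnDataModule
can pre-bind them with functools.partial and hand the DataLoader a one-argument callable.
A toy sketch of that binding pattern (names and values here are illustrative only):

    from functools import partial

    def collate(scale: int, examples: list) -> list:
        # stand-in for data/csv.py's collate_fn: config first, batch of examples last
        return [scale * x for x in examples]

    collate_fn_ = partial(collate, 10)   # pre-bind the configuration
    print(collate_fn_([1, 2, 3]))        # -> [10, 20, 30]
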
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 46b414b..9a23a2a 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -99,12 +99,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return embeds
 
-    def normalize(self, lambda_: float = 1.0):
+    def normalize(self, target: float = 0.4, lambda_: float = 1.0):
         w = self.temp_token_embedding.weight
         pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
         w[self.temp_token_ids] = F.normalize(
             w[self.temp_token_ids, :], dim=-1
-        ) * (pre_norm + lambda_ * (0.4 - pre_norm))
+        ) * (pre_norm + lambda_ * (target - pre_norm))
 
     def forward(
         self,
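
Note: normalize() rescales each trained token embedding toward a configurable target norm;
with lambda_ = 1.0 the norm becomes exactly `target`, smaller values move it only part of
the way. A standalone sketch of the same arithmetic on toy tensors (not the real embedding
table):

    import torch
    import torch.nn.functional as F

    w = torch.randn(3, 8) * 2.0                  # toy embedding rows
    target, lambda_ = 0.4, 0.5
    pre_norm = w.norm(dim=-1, keepdim=True)
    w = F.normalize(w, dim=-1) * (pre_norm + lambda_ * (target - pre_norm))
    print(w.norm(dim=-1))                        # norms moved halfway toward 0.4
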
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cb300d1..6bc40e9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,7 +20,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import logging
+from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
 
@@ -250,8 +250,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return timesteps
 
-    def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -260,28 +260,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
 
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
-            latents = latents.to(device)
+            latents = latents.to(device=device, dtype=dtype)
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
 
         return latents
 
-    def prepare_latents_from_image(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+    def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
         init_latent_dist = self.vae.encode(init_image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
@@ -292,7 +280,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
             )
         else:
-            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+            init_latents = torch.cat([init_latents] * batch_size, dim=0)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
@@ -430,16 +418,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
             latents = self.prepare_latents_from_image(
                 image,
                 latent_timestep,
-                batch_size,
-                num_images_per_prompt,
+                batch_size * num_images_per_prompt,
                 text_embeddings.dtype,
                 device,
                 generator
             )
         else:
             latents = self.prepare_latents(
-                batch_size,
-                num_images_per_prompt,
+                batch_size * num_images_per_prompt,
                 num_channels_latents,
                 height,
                 width,
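
Note: callers now fold num_images_per_prompt into the batch size before calling
prepare_latents, and the latent resolution is derived from self.vae_scale_factor instead of
a hard-coded 8. A small shape check, assuming the usual scale factor of 8 and 4 latent
channels:

    batch_size, num_images_per_prompt = 2, 3
    num_channels_latents, height, width, vae_scale_factor = 4, 512, 512, 8
    shape = (batch_size * num_images_per_prompt, num_channels_latents,
             height // vae_scale_factor, width // vae_scale_factor)
    assert shape == (6, 4, 64, 64)
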
diff --git a/train_dreambooth.py b/train_dreambooth.py
index ebcf802..da3a075 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -14,7 +14,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
@@ -24,8 +23,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images
-from training.optimization import get_one_cycle_schedule
+from training.common import loss_step, generate_class_images, get_scheduler
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.embeddings import patch_managed_embeddings
@@ -750,35 +748,6 @@ def main():
         )
         return cond3 and cond4
 
-    def collate_fn(examples):
-        prompt_ids = [example["prompt_ids"] for example in examples]
-        nprompt_ids = [example["nprompt_ids"] for example in examples]
-
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # concat class and instance examples for prior preservation
-        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
-
-        prompts = prompt_processor.unify_input_ids(prompt_ids)
-        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-        inputs = prompt_processor.unify_input_ids(input_ids)
-
-        batch = {
-            "prompt_ids": prompts.input_ids,
-            "nprompt_ids": nprompts.input_ids,
-            "input_ids": inputs.input_ids,
-            "pixel_values": pixel_values,
-            "attention_mask": inputs.attention_mask,
-        }
-
-        return batch
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -798,7 +767,7 @@ def main():
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
-        collate_fn=collate_fn
+        dtype=weight_dtype
     )
 
     datamodule.prepare_data()
@@ -829,33 +798,23 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
-
-    if args.lr_scheduler == "one_cycle":
-        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            warmup=args.lr_warmup_func,
-            annealing=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            min_lr=lr_min_lr,
-        )
-    elif args.lr_scheduler == "cosine_with_restarts":
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
-        )
+    if args.find_lr:
+        lr_scheduler = None
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            lr=args.learning_rate,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            warmup_epochs=args.lr_warmup_epochs,
+            max_train_steps=args.max_train_steps,
+            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            gradient_accumulation_steps=args.gradient_accumulation_steps
        )
 
     unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
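
Note: the warmup-step arithmetic that used to live here moved into training/common.py's
get_scheduler, which computes warmup_steps from epochs, updates per epoch and accumulation
steps. For example (illustrative numbers only):

    warmup_steps = 10 * 50 * 2   # 10 warmup epochs * 50 updates/epoch * 2 accumulation = 1000
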
diff --git a/train_ti.py b/train_ti.py
index 9ec5cfb..3b7e3b1 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,7 +13,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel
@@ -22,8 +21,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images
-from training.optimization import get_one_cycle_schedule
+from training.common import loss_step, generate_class_images, get_scheduler
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
 from models.clip.embeddings import patch_managed_embeddings
@@ -410,10 +408,16 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--max_grad_norm",
-        default=3.0,
+        "--decay_target",
+        default=0.4,
         type=float,
-        help="Max gradient norm."
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--decay_factor",
+        default=100,
+        type=float,
+        help="Embedding decay factor."
     )
     parser.add_argument(
         "--noise_timesteps",
@@ -709,35 +713,6 @@ def main():
         )
         return cond1 and cond3 and cond4
 
-    def collate_fn(examples):
-        prompt_ids = [example["prompt_ids"] for example in examples]
-        nprompt_ids = [example["nprompt_ids"] for example in examples]
-
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # concat class and instance examples for prior preservation
-        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
-
-        prompts = prompt_processor.unify_input_ids(prompt_ids)
-        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-        inputs = prompt_processor.unify_input_ids(input_ids)
-
-        batch = {
-            "prompt_ids": prompts.input_ids,
-            "nprompt_ids": nprompts.input_ids,
-            "input_ids": inputs.input_ids,
-            "pixel_values": pixel_values,
-            "attention_mask": inputs.attention_mask,
-        }
-
-        return batch
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -757,7 +732,7 @@ def main():
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
-        collate_fn=collate_fn
+        dtype=weight_dtype
     )
     datamodule.setup()
 
@@ -786,35 +761,23 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
-
     if args.find_lr:
         lr_scheduler = None
-    elif args.lr_scheduler == "one_cycle":
-        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            warmup=args.lr_warmup_func,
-            annealing=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            min_lr=lr_min_lr,
-        )
-    elif args.lr_scheduler == "cosine_with_restarts":
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
-        )
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            lr=args.learning_rate,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            warmup_epochs=args.lr_warmup_epochs,
+            max_train_steps=args.max_train_steps,
+            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            gradient_accumulation_steps=args.gradient_accumulation_steps
        )
 
     text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
@@ -868,7 +831,10 @@ def main():
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+        text_encoder.text_model.embeddings.normalize(
+            args.decay_target,
+            min(1.0, args.decay_factor * lr)
+        )
 
     loop = partial(
         loss_step,
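
Note: on_after_optimize now scales the embedding decay strength with the current learning
rate and clamps it at 1.0. With the default decay_factor of 100, an lr of 1e-4 gives
lambda_ = 0.01, and any lr >= 0.01 saturates at 1.0. A tiny sketch of that clamping
(helper name is hypothetical):

    def decay_lambda(lr: float, decay_factor: float = 100.0) -> float:
        # mirrors min(1.0, decay_factor * lr) from on_after_optimize
        return min(1.0, decay_factor * lr)

    print(decay_lambda(1e-4))   # 0.01
    print(decay_lambda(0.05))   # 1.0
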
diff --git a/training/common.py b/training/common.py
index 0b2ae44..90cf910 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,10 +1,65 @@
+import math
+
 import torch
 import torch.nn.functional as F
 
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 
+from training.optimization import get_one_cycle_schedule
+
+
+def get_scheduler(
+    id: str,
+    min_lr: float,
+    lr: float,
+    warmup_func: str,
+    annealing_func: str,
+    warmup_exp: int,
+    annealing_exp: int,
+    cycles: int,
+    warmup_epochs: int,
+    optimizer: torch.optim.Optimizer,
+    max_train_steps: int,
+    num_update_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+):
+    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+
+    if id == "one_cycle":
+        min_lr = 0.04 if min_lr is None else min_lr / lr
+
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        cycles = cycles if cycles is not None else math.ceil(
+            math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+        )
+
+    return lr_scheduler
+
 
 def generate_class_images(
     accelerator,
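
Note: train_dreambooth.py and train_ti.py now share this single scheduler factory. A hedged
usage sketch (dummy optimizer and step counts; "cosine" falls through to diffusers'
get_scheduler, so the one-cycle-specific arguments are unused in that branch):

    import torch
    from training.common import get_scheduler

    # dummy optimizer just to illustrate the call signature
    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)

    lr_scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        min_lr=None,
        lr=1e-4,
        warmup_func=None,
        annealing_func=None,
        warmup_exp=1,
        annealing_exp=1,
        cycles=None,
        warmup_epochs=10,
        max_train_steps=1000,
        num_update_steps_per_epoch=10,
        gradient_accumulation_steps=1,
    )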