summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_lora.py17
-rw-r--r--train_ti.py17
-rw-r--r--training/functional.py20
-rw-r--r--training/sampler.py154
-rw-r--r--training/strategy/lora.py10
5 files changed, 205 insertions, 13 deletions
diff --git a/train_lora.py b/train_lora.py
index cc7c1ec..70fbae4 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -27,6 +27,7 @@ from data.csv import VlpnDataModule, keyword_filter
27from training.functional import train, add_placeholder_tokens, get_models 27from training.functional import train, add_placeholder_tokens, get_models
28from training.strategy.lora import lora_strategy 28from training.strategy.lora import lora_strategy
29from training.optimization import get_scheduler 29from training.optimization import get_scheduler
30from training.sampler import create_named_schedule_sampler
30from training.util import AverageMeter, save_args 31from training.util import AverageMeter, save_args
31 32
32# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py 33# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
@@ -410,6 +411,19 @@ def parse_args():
410 help="Minimum learning rate in the lr scheduler." 411 help="Minimum learning rate in the lr scheduler."
411 ) 412 )
412 parser.add_argument( 413 parser.add_argument(
414 "--min_snr_gamma",
415 type=int,
416 default=5,
417 help="MinSNR gamma."
418 )
419 parser.add_argument(
420 "--schedule_sampler",
421 type=str,
422 default="uniform",
423 choices=["uniform", "loss-second-moment"],
424 help="Noise schedule sampler."
425 )
426 parser.add_argument(
413 "--optimizer", 427 "--optimizer",
414 type=str, 428 type=str,
415 default="adan", 429 default="adan",
@@ -708,6 +722,7 @@ def main():
708 args.emb_alpha, 722 args.emb_alpha,
709 args.emb_dropout 723 args.emb_dropout
710 ) 724 )
725 schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
711 726
712 unet_config = LoraConfig( 727 unet_config = LoraConfig(
713 r=args.lora_r, 728 r=args.lora_r,
@@ -923,6 +938,8 @@ def main():
923 tokenizer=tokenizer, 938 tokenizer=tokenizer,
924 vae=vae, 939 vae=vae,
925 noise_scheduler=noise_scheduler, 940 noise_scheduler=noise_scheduler,
941 schedule_sampler=schedule_sampler,
942 min_snr_gamma=args.min_snr_gamma,
926 dtype=weight_dtype, 943 dtype=weight_dtype,
927 seed=args.seed, 944 seed=args.seed,
928 compile_unet=args.compile_unet, 945 compile_unet=args.compile_unet,
diff --git a/train_ti.py b/train_ti.py
index ae73639..26f7941 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -23,6 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter
23from training.functional import train, add_placeholder_tokens, get_models 23from training.functional import train, add_placeholder_tokens, get_models
24from training.strategy.ti import textual_inversion_strategy 24from training.strategy.ti import textual_inversion_strategy
25from training.optimization import get_scheduler 25from training.optimization import get_scheduler
26from training.sampler import create_named_schedule_sampler
26from training.util import AverageMeter, save_args 27from training.util import AverageMeter, save_args
27 28
28logger = get_logger(__name__) 29logger = get_logger(__name__)
@@ -359,6 +360,19 @@ def parse_args():
359 default=0.9999 360 default=0.9999
360 ) 361 )
361 parser.add_argument( 362 parser.add_argument(
363 "--min_snr_gamma",
364 type=int,
365 default=5,
366 help="MinSNR gamma."
367 )
368 parser.add_argument(
369 "--schedule_sampler",
370 type=str,
371 default="uniform",
372 choices=["uniform", "loss-second-moment"],
373 help="Noise schedule sampler."
374 )
375 parser.add_argument(
362 "--optimizer", 376 "--optimizer",
363 type=str, 377 type=str,
364 default="adan", 378 default="adan",
@@ -682,6 +696,7 @@ def main():
682 args.emb_alpha, 696 args.emb_alpha,
683 args.emb_dropout 697 args.emb_dropout
684 ) 698 )
699 schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
685 700
686 tokenizer.set_use_vector_shuffle(args.vector_shuffle) 701 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
687 tokenizer.set_dropout(args.vector_dropout) 702 tokenizer.set_dropout(args.vector_dropout)
@@ -837,6 +852,8 @@ def main():
837 tokenizer=tokenizer, 852 tokenizer=tokenizer,
838 vae=vae, 853 vae=vae,
839 noise_scheduler=noise_scheduler, 854 noise_scheduler=noise_scheduler,
855 schedule_sampler=schedule_sampler,
856 min_snr_gamma=args.min_snr_gamma,
840 dtype=weight_dtype, 857 dtype=weight_dtype,
841 seed=args.seed, 858 seed=args.seed,
842 compile_unet=args.compile_unet, 859 compile_unet=args.compile_unet,
diff --git a/training/functional.py b/training/functional.py
index e7e1eb3..eae5681 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -25,6 +25,7 @@ from models.clip.util import get_extended_embeddings
25from models.clip.tokenizer import MultiCLIPTokenizer 25from models.clip.tokenizer import MultiCLIPTokenizer
26from models.convnext.discriminator import ConvNeXtDiscriminator 26from models.convnext.discriminator import ConvNeXtDiscriminator
27from training.util import AverageMeter 27from training.util import AverageMeter
28from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler
28from util.slerp import slerp 29from util.slerp import slerp
29 30
30 31
@@ -318,6 +319,7 @@ def get_original(
318def loss_step( 319def loss_step(
319 vae: AutoencoderKL, 320 vae: AutoencoderKL,
320 noise_scheduler: SchedulerMixin, 321 noise_scheduler: SchedulerMixin,
322 schedule_sampler: ScheduleSampler,
321 unet: UNet2DConditionModel, 323 unet: UNet2DConditionModel,
322 text_encoder: CLIPTextModel, 324 text_encoder: CLIPTextModel,
323 guidance_scale: float, 325 guidance_scale: float,
@@ -362,14 +364,7 @@ def loss_step(
362 new_noise = noise + input_pertubation * torch.randn_like(noise) 364 new_noise = noise + input_pertubation * torch.randn_like(noise)
363 365
364 # Sample a random timestep for each image 366 # Sample a random timestep for each image
365 timesteps = torch.randint( 367 timesteps, weights = schedule_sampler.sample(bsz, latents.device)
366 0,
367 noise_scheduler.config.num_train_timesteps,
368 (bsz,),
369 generator=generator,
370 device=latents.device,
371 )
372 timesteps = timesteps.long()
373 368
374 # Add noise to the latents according to the noise magnitude at each timestep 369 # Add noise to the latents according to the noise magnitude at each timestep
375 # (this is the forward diffusion process) 370 # (this is the forward diffusion process)
@@ -443,6 +438,10 @@ def loss_step(
443 ) 438 )
444 loss = loss * mse_loss_weights 439 loss = loss * mse_loss_weights
445 440
441 if isinstance(schedule_sampler, LossAwareSampler):
442 schedule_sampler.update_with_all_losses(timesteps, loss.detach())
443
444 loss = loss * weights
446 loss = loss.mean() 445 loss = loss.mean()
447 446
448 return loss, acc, bsz 447 return loss, acc, bsz
@@ -694,6 +693,7 @@ def train(
694 offset_noise_strength: float = 0.01, 693 offset_noise_strength: float = 0.01,
695 input_pertubation: float = 0.1, 694 input_pertubation: float = 0.1,
696 disc: Optional[ConvNeXtDiscriminator] = None, 695 disc: Optional[ConvNeXtDiscriminator] = None,
696 schedule_sampler: Optional[ScheduleSampler] = None,
697 min_snr_gamma: int = 5, 697 min_snr_gamma: int = 5,
698 avg_loss: AverageMeter = AverageMeter(), 698 avg_loss: AverageMeter = AverageMeter(),
699 avg_acc: AverageMeter = AverageMeter(), 699 avg_acc: AverageMeter = AverageMeter(),
@@ -725,10 +725,14 @@ def train(
725 **kwargs, 725 **kwargs,
726 ) 726 )
727 727
728 if schedule_sampler is None:
729 schedule_sampler = UniformSampler(noise_scheduler.config.num_train_timesteps)
730
728 loss_step_ = partial( 731 loss_step_ = partial(
729 loss_step, 732 loss_step,
730 vae, 733 vae,
731 noise_scheduler, 734 noise_scheduler,
735 schedule_sampler,
732 unet, 736 unet,
733 text_encoder, 737 text_encoder,
734 guidance_scale, 738 guidance_scale,
diff --git a/training/sampler.py b/training/sampler.py
new file mode 100644
index 0000000..8afe255
--- /dev/null
+++ b/training/sampler.py
@@ -0,0 +1,154 @@
1from abc import ABC, abstractmethod
2
3import numpy as np
4import torch
5import torch.distributed as dist
6
7
8def create_named_schedule_sampler(name, num_timesteps):
9 """
10 Create a ScheduleSampler from a library of pre-defined samplers.
11
12 :param name: the name of the sampler.
13 :param diffusion: the diffusion object to sample for.
14 """
15 if name == "uniform":
16 return UniformSampler(num_timesteps)
17 elif name == "loss-second-moment":
18 return LossSecondMomentResampler(num_timesteps)
19 else:
20 raise NotImplementedError(f"unknown schedule sampler: {name}")
21
22
23class ScheduleSampler(ABC):
24 """
25 A distribution over timesteps in the diffusion process, intended to reduce
26 variance of the objective.
27
28 By default, samplers perform unbiased importance sampling, in which the
29 objective's mean is unchanged.
30 However, subclasses may override sample() to change how the resampled
31 terms are reweighted, allowing for actual changes in the objective.
32 """
33
34 @abstractmethod
35 def weights(self):
36 """
37 Get a numpy array of weights, one per diffusion step.
38
39 The weights needn't be normalized, but must be positive.
40 """
41
42 def sample(self, batch_size, device):
43 """
44 Importance-sample timesteps for a batch.
45
46 :param batch_size: the number of timesteps.
47 :param device: the torch device to save to.
48 :return: a tuple (timesteps, weights):
49 - timesteps: a tensor of timestep indices.
50 - weights: a tensor of weights to scale the resulting losses.
51 """
52 w = self.weights()
53 p = w / np.sum(w)
54 indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55 indices = torch.from_numpy(indices_np).long().to(device)
56 weights_np = 1 / (len(p) * p[indices_np])
57 weights = torch.from_numpy(weights_np).float().to(device)
58 return indices, weights
59
60
61class UniformSampler(ScheduleSampler):
62 def __init__(self, num_timesteps):
63 self.num_timesteps = num_timesteps
64 self._weights = np.ones([num_timesteps])
65
66 def weights(self):
67 return self._weights
68
69
70class LossAwareSampler(ScheduleSampler):
71 def update_with_local_losses(self, local_ts, local_losses):
72 """
73 Update the reweighting using losses from a model.
74
75 Call this method from each rank with a batch of timesteps and the
76 corresponding losses for each of those timesteps.
77 This method will perform synchronization to make sure all of the ranks
78 maintain the exact same reweighting.
79
80 :param local_ts: an integer Tensor of timesteps.
81 :param local_losses: a 1D Tensor of losses.
82 """
83 batch_sizes = [
84 torch.tensor([0], dtype=torch.int32, device=local_ts.device)
85 for _ in range(dist.get_world_size())
86 ]
87 dist.all_gather(
88 batch_sizes,
89 torch.tensor([len(local_ts)], dtype=torch.int32, device=local_ts.device),
90 )
91
92 # Pad all_gather batches to be the maximum batch size.
93 batch_sizes = [x.item() for x in batch_sizes]
94 max_bs = max(batch_sizes)
95
96 timestep_batches = [torch.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97 loss_batches = [torch.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98 dist.all_gather(timestep_batches, local_ts)
99 dist.all_gather(loss_batches, local_losses)
100 timesteps = [
101 x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102 ]
103 losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104 self.update_with_all_losses(timesteps, losses)
105
106 @abstractmethod
107 def update_with_all_losses(self, ts, losses):
108 """
109 Update the reweighting using losses from a model.
110
111 Sub-classes should override this method to update the reweighting
112 using losses from the model.
113
114 This method directly updates the reweighting without synchronizing
115 between workers. It is called by update_with_local_losses from all
116 ranks with identical arguments. Thus, it should have deterministic
117 behavior to maintain state across workers.
118
119 :param ts: a list of int timesteps.
120 :param losses: a list of float losses, one per timestep.
121 """
122
123
124class LossSecondMomentResampler(LossAwareSampler):
125 def __init__(self, num_timesteps, history_per_term=10, uniform_prob=0.001):
126 self.num_timesteps = num_timesteps
127 self.history_per_term = history_per_term
128 self.uniform_prob = uniform_prob
129 self._loss_history = np.zeros(
130 [self.num_timesteps, history_per_term], dtype=np.float64
131 )
132 self._loss_counts = np.zeros([self.num_timesteps], dtype=np.int)
133
134 def weights(self):
135 if not self._warmed_up():
136 return np.ones([self.num_timesteps], dtype=np.float64)
137 weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 weights /= np.sum(weights)
139 weights *= 1 - self.uniform_prob
140 weights += self.uniform_prob / len(weights)
141 return weights
142
143 def update_with_all_losses(self, ts, losses):
144 for t, loss in zip(ts, losses):
145 if self._loss_counts[t] == self.history_per_term:
146 # Shift out the oldest loss term.
147 self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 self._loss_history[t, -1] = loss
149 else:
150 self._loss_history[t, self._loss_counts[t]] = loss
151 self._loss_counts[t] += 1
152
153 def _warmed_up(self):
154 return (self._loss_counts == self.history_per_term).all()
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 3f4dbbc..0c0f633 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -120,11 +120,11 @@ def lora_strategy_callbacks(
120 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) 120 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
121 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) 121 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
122 122
123 for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): 123 # for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
124 text_encoder_.text_model.embeddings.save_embed( 124 # text_encoder_.text_model.embeddings.save_embed(
125 ids, 125 # ids,
126 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" 126 # checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
127 ) 127 # )
128 128
129 if not pti_mode: 129 if not pti_mode:
130 lora_config = {} 130 lora_config = {}