summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py20
-rw-r--r--training/sampler.py154
-rw-r--r--training/strategy/lora.py10
3 files changed, 171 insertions, 13 deletions
diff --git a/training/functional.py b/training/functional.py
index e7e1eb3..eae5681 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -25,6 +25,7 @@ from models.clip.util import get_extended_embeddings
25from models.clip.tokenizer import MultiCLIPTokenizer 25from models.clip.tokenizer import MultiCLIPTokenizer
26from models.convnext.discriminator import ConvNeXtDiscriminator 26from models.convnext.discriminator import ConvNeXtDiscriminator
27from training.util import AverageMeter 27from training.util import AverageMeter
28from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler
28from util.slerp import slerp 29from util.slerp import slerp
29 30
30 31
@@ -318,6 +319,7 @@ def get_original(
318def loss_step( 319def loss_step(
319 vae: AutoencoderKL, 320 vae: AutoencoderKL,
320 noise_scheduler: SchedulerMixin, 321 noise_scheduler: SchedulerMixin,
322 schedule_sampler: ScheduleSampler,
321 unet: UNet2DConditionModel, 323 unet: UNet2DConditionModel,
322 text_encoder: CLIPTextModel, 324 text_encoder: CLIPTextModel,
323 guidance_scale: float, 325 guidance_scale: float,
@@ -362,14 +364,7 @@ def loss_step(
362 new_noise = noise + input_pertubation * torch.randn_like(noise) 364 new_noise = noise + input_pertubation * torch.randn_like(noise)
363 365
364 # Sample a random timestep for each image 366 # Sample a random timestep for each image
365 timesteps = torch.randint( 367 timesteps, weights = schedule_sampler.sample(bsz, latents.device)
366 0,
367 noise_scheduler.config.num_train_timesteps,
368 (bsz,),
369 generator=generator,
370 device=latents.device,
371 )
372 timesteps = timesteps.long()
373 368
374 # Add noise to the latents according to the noise magnitude at each timestep 369 # Add noise to the latents according to the noise magnitude at each timestep
375 # (this is the forward diffusion process) 370 # (this is the forward diffusion process)
@@ -443,6 +438,10 @@ def loss_step(
443 ) 438 )
444 loss = loss * mse_loss_weights 439 loss = loss * mse_loss_weights
445 440
441 if isinstance(schedule_sampler, LossAwareSampler):
442 schedule_sampler.update_with_all_losses(timesteps, loss.detach())
443
444 loss = loss * weights
446 loss = loss.mean() 445 loss = loss.mean()
447 446
448 return loss, acc, bsz 447 return loss, acc, bsz
@@ -694,6 +693,7 @@ def train(
694 offset_noise_strength: float = 0.01, 693 offset_noise_strength: float = 0.01,
695 input_pertubation: float = 0.1, 694 input_pertubation: float = 0.1,
696 disc: Optional[ConvNeXtDiscriminator] = None, 695 disc: Optional[ConvNeXtDiscriminator] = None,
696 schedule_sampler: Optional[ScheduleSampler] = None,
697 min_snr_gamma: int = 5, 697 min_snr_gamma: int = 5,
698 avg_loss: AverageMeter = AverageMeter(), 698 avg_loss: AverageMeter = AverageMeter(),
699 avg_acc: AverageMeter = AverageMeter(), 699 avg_acc: AverageMeter = AverageMeter(),
@@ -725,10 +725,14 @@ def train(
725 **kwargs, 725 **kwargs,
726 ) 726 )
727 727
728 if schedule_sampler is None:
729 schedule_sampler = UniformSampler(noise_scheduler.config.num_train_timesteps)
730
728 loss_step_ = partial( 731 loss_step_ = partial(
729 loss_step, 732 loss_step,
730 vae, 733 vae,
731 noise_scheduler, 734 noise_scheduler,
735 schedule_sampler,
732 unet, 736 unet,
733 text_encoder, 737 text_encoder,
734 guidance_scale, 738 guidance_scale,
diff --git a/training/sampler.py b/training/sampler.py
new file mode 100644
index 0000000..8afe255
--- /dev/null
+++ b/training/sampler.py
@@ -0,0 +1,154 @@
1from abc import ABC, abstractmethod
2
3import numpy as np
4import torch
5import torch.distributed as dist
6
7
8def create_named_schedule_sampler(name, num_timesteps):
9 """
10 Create a ScheduleSampler from a library of pre-defined samplers.
11
12 :param name: the name of the sampler.
13 :param diffusion: the diffusion object to sample for.
14 """
15 if name == "uniform":
16 return UniformSampler(num_timesteps)
17 elif name == "loss-second-moment":
18 return LossSecondMomentResampler(num_timesteps)
19 else:
20 raise NotImplementedError(f"unknown schedule sampler: {name}")
21
22
23class ScheduleSampler(ABC):
24 """
25 A distribution over timesteps in the diffusion process, intended to reduce
26 variance of the objective.
27
28 By default, samplers perform unbiased importance sampling, in which the
29 objective's mean is unchanged.
30 However, subclasses may override sample() to change how the resampled
31 terms are reweighted, allowing for actual changes in the objective.
32 """
33
34 @abstractmethod
35 def weights(self):
36 """
37 Get a numpy array of weights, one per diffusion step.
38
39 The weights needn't be normalized, but must be positive.
40 """
41
42 def sample(self, batch_size, device):
43 """
44 Importance-sample timesteps for a batch.
45
46 :param batch_size: the number of timesteps.
47 :param device: the torch device to save to.
48 :return: a tuple (timesteps, weights):
49 - timesteps: a tensor of timestep indices.
50 - weights: a tensor of weights to scale the resulting losses.
51 """
52 w = self.weights()
53 p = w / np.sum(w)
54 indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55 indices = torch.from_numpy(indices_np).long().to(device)
56 weights_np = 1 / (len(p) * p[indices_np])
57 weights = torch.from_numpy(weights_np).float().to(device)
58 return indices, weights
59
60
61class UniformSampler(ScheduleSampler):
62 def __init__(self, num_timesteps):
63 self.num_timesteps = num_timesteps
64 self._weights = np.ones([num_timesteps])
65
66 def weights(self):
67 return self._weights
68
69
70class LossAwareSampler(ScheduleSampler):
71 def update_with_local_losses(self, local_ts, local_losses):
72 """
73 Update the reweighting using losses from a model.
74
75 Call this method from each rank with a batch of timesteps and the
76 corresponding losses for each of those timesteps.
77 This method will perform synchronization to make sure all of the ranks
78 maintain the exact same reweighting.
79
80 :param local_ts: an integer Tensor of timesteps.
81 :param local_losses: a 1D Tensor of losses.
82 """
83 batch_sizes = [
84 torch.tensor([0], dtype=torch.int32, device=local_ts.device)
85 for _ in range(dist.get_world_size())
86 ]
87 dist.all_gather(
88 batch_sizes,
89 torch.tensor([len(local_ts)], dtype=torch.int32, device=local_ts.device),
90 )
91
92 # Pad all_gather batches to be the maximum batch size.
93 batch_sizes = [x.item() for x in batch_sizes]
94 max_bs = max(batch_sizes)
95
96 timestep_batches = [torch.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97 loss_batches = [torch.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98 dist.all_gather(timestep_batches, local_ts)
99 dist.all_gather(loss_batches, local_losses)
100 timesteps = [
101 x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102 ]
103 losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104 self.update_with_all_losses(timesteps, losses)
105
106 @abstractmethod
107 def update_with_all_losses(self, ts, losses):
108 """
109 Update the reweighting using losses from a model.
110
111 Sub-classes should override this method to update the reweighting
112 using losses from the model.
113
114 This method directly updates the reweighting without synchronizing
115 between workers. It is called by update_with_local_losses from all
116 ranks with identical arguments. Thus, it should have deterministic
117 behavior to maintain state across workers.
118
119 :param ts: a list of int timesteps.
120 :param losses: a list of float losses, one per timestep.
121 """
122
123
124class LossSecondMomentResampler(LossAwareSampler):
125 def __init__(self, num_timesteps, history_per_term=10, uniform_prob=0.001):
126 self.num_timesteps = num_timesteps
127 self.history_per_term = history_per_term
128 self.uniform_prob = uniform_prob
129 self._loss_history = np.zeros(
130 [self.num_timesteps, history_per_term], dtype=np.float64
131 )
132 self._loss_counts = np.zeros([self.num_timesteps], dtype=np.int)
133
134 def weights(self):
135 if not self._warmed_up():
136 return np.ones([self.num_timesteps], dtype=np.float64)
137 weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 weights /= np.sum(weights)
139 weights *= 1 - self.uniform_prob
140 weights += self.uniform_prob / len(weights)
141 return weights
142
143 def update_with_all_losses(self, ts, losses):
144 for t, loss in zip(ts, losses):
145 if self._loss_counts[t] == self.history_per_term:
146 # Shift out the oldest loss term.
147 self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 self._loss_history[t, -1] = loss
149 else:
150 self._loss_history[t, self._loss_counts[t]] = loss
151 self._loss_counts[t] += 1
152
153 def _warmed_up(self):
154 return (self._loss_counts == self.history_per_term).all()
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 3f4dbbc..0c0f633 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -120,11 +120,11 @@ def lora_strategy_callbacks(
120 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) 120 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
121 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) 121 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
122 122
123 for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): 123 # for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
124 text_encoder_.text_model.embeddings.save_embed( 124 # text_encoder_.text_model.embeddings.save_embed(
125 ids, 125 # ids,
126 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" 126 # checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
127 ) 127 # )
128 128
129 if not pti_mode: 129 if not pti_mode:
130 lora_config = {} 130 lora_config = {}